CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.h
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines a general cut-based rewriting framework for
10// combinational logic optimization. The framework uses NPN-equivalence matching
11// with area and delay metrics to rewrite cuts (subgraphs) in combinational
12// circuits with optimal patterns.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
18
20#include "circt/Support/LLVM.h"
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
31#include <memory>
32#include <optional>
33#include <utility>
34
35namespace circt {
36namespace synth {
/// Integral type used for every delay value in the circuit.
/// This framework does not fix a unit: it is the user's responsibility to
/// express all delays in one consistent unit (usually femtoseconds, but any
/// unit works as long as it is used everywhere).
using DelayType = std::int64_t;
41
/// Maximum number of inputs supported for truth table generation.
/// This limit prevents excessive memory usage, since the truth table size
/// grows exponentially with the number of inputs (2^n entries).
///
/// Declared `inline constexpr` rather than namespace-scope `static` so that
/// every translation unit including this header refers to one single entity.
/// This avoids per-TU copies and ODR hazards when the constant is odr-used
/// (e.g. bound to a `const &`) from inline code.
inline constexpr unsigned maxTruthTableInputs = 16;
46
// Helper that sorts the operations of a logic network topologically.
// Cut rewriting requires every operation to be processed after the operations
// it depends on, so this ordering must be established before rewriting.
// `op` is the operation holding the logic network; presumably its body is
// reordered in place -- TODO(review): confirm against the definition.
// Returns failure if the network cannot be sorted.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op);

// Get the truth table computed for `values` from the logic inside `block`.
// The block must be an SSACFG region or already be topologically sorted.
// Returns failure if a truth table cannot be derived for `values`.
FailureOr<BinaryTruthTable> getTruthTable(ValueRange values, Block *block);
55
56//===----------------------------------------------------------------------===//
57// Cut Data Structures
58//===----------------------------------------------------------------------===//
59
// Forward declarations of classes defined later in this header (listed
// alphabetically).
class CutEnumerator;
class CutRewriter;
class LogicNetwork;
67
68//===----------------------------------------------------------------------===//
69// Logic Network Data Structures (Flat IR for efficient cut enumeration)
70//===----------------------------------------------------------------------===//
71
/// Edge representation in the logic network.
/// Similar to mockturtle's signal, this encodes both a node index and an
/// inversion flag in a single 32-bit value: the node index lives in the upper
/// 31 bits and the LSB indicates whether the signal is inverted.
struct Signal {
  /// Raw encoding: (index << 1) | inverted.
  uint32_t data = 0;

  Signal() = default;

  /// Create a signal referring to node `index`, optionally inverted.
  /// `index` must fit in 31 bits; a larger value would silently lose its top
  /// bit when shifted into the encoding.
  Signal(uint32_t index, bool inverted)
      : data((index << 1) | (inverted ? 1u : 0u)) {
    assert(index < (1u << 31) && "Signal index does not fit in 31 bits");
  }

  /// Create a signal from its raw encoding (see getRaw()).
  explicit Signal(uint32_t raw) : data(raw) {}

  /// Get the node index (without the inversion bit).
  uint32_t getIndex() const { return data >> 1; }

  /// Check if this edge is inverted.
  bool isInverted() const { return data & 1; }

  /// Get the raw data (index << 1 | inverted).
  uint32_t getRaw() const { return data; }

  /// Return this signal with its inversion bit toggled. Same as operator!.
  Signal flipInversion() const { return !*this; }

  /// Return this signal with its inversion bit toggled: an inverted signal
  /// becomes non-inverted and vice versa. The node index is unchanged.
  Signal operator!() const { return Signal(data ^ 1u); }

  bool operator==(const Signal &other) const { return data == other.data; }
  bool operator!=(const Signal &other) const { return data != other.data; }
  /// Orders signals by raw encoding (node index first, then inversion bit).
  bool operator<(const Signal &other) const { return data < other.data; }
};
101
102/// Represents a single gate/node in the flat logic network.
103/// This structure is designed to be cache-friendly and supports up to 3 inputs
104/// (sufficient for AND, XOR, MAJ gates). For nodes with fewer inputs, unused
105/// edges have index 0 (constant 0 node).
106///
107/// Special indices:
108/// - Index 0: Constant 0
109/// - Index 1: Constant 1
110///
/// It uses 8 bytes for the operation pointer, at least 1 byte for the kind,
/// and 12 bytes for the three edges, which typically pads to 24 bytes per
/// gate on 64-bit hosts.
114 /// Kind of logic gate.
115 enum Kind : uint8_t {
116 Constant = 0, ///< Constant 0/1 node (index 0 = const0, index 1 = const1)
117 PrimaryInput = 1, ///< Primary input to the network
118 And2 = 2, ///< AND gate (2-input, aig::AndInverterOp)
119 Xor2 = 3, ///< XOR gate (2-input)
120 Maj3 = 4, ///< Reserved 3-input gate kind
121 Identity = 5, ///< Identity gate (used for 1-input inverter)
122 Choice = 6, ///< Choice node (synth.choice)
123 Dot3 = 7, ///< Ordered DOT gate (3-input, synth.dot)
124 OneHot3 = 8, ///< OneHot gate (3-input, synth.onehot)
125 Mux3 = 9, ///< Ordered MUX gate (3-input, synth.mux_inv)
126 };
127
128 /// Operation pointer and gate kind. Constants have a null operation pointer.
129 Operation *op = nullptr;
131
132 /// Fanin edges (up to 3 inputs). For AND gates, only edges[0] and edges[1]
133 /// are used. For PrimaryInput/Constant and Choice, none are used. The
134 /// inversion bit is encoded in each edge.
136
137 LogicNetworkGate() : op(nullptr), kind(Constant), edges{} {}
139 llvm::ArrayRef<Signal> operands = {})
140 : op(op), kind(kind), edges{} {
141 assert(operands.size() <= 3 && "Too many operands for LogicNetworkGate");
142 for (size_t i = 0; i < operands.size(); ++i)
143 edges[i] = operands[i];
144 }
145
146 /// Get the kind of this gate.
147 Kind getKind() const { return kind; }
148
149 /// Get the operation pointer (nullptr for constants).
150 Operation *getOperation() const { return op; }
151
152 /// Get the number of fanin edges based on kind.
153 unsigned getNumFanins() const {
154 switch (getKind()) {
155 case Constant:
156 case PrimaryInput:
157 return 0;
158 case And2:
159 case Xor2:
160 return 2;
161 case Maj3:
162 return 3;
163 case Dot3:
164 return 3;
165 case OneHot3:
166 return 3;
167 case Mux3:
168 return 3;
169 case Identity:
170 return 1;
171 case Choice:
172 return 0;
173 }
174 llvm_unreachable("Unknown gate kind");
175 }
176
177 /// Check if this is a logic gate that can be part of a cut.
178 bool isLogicGate() const {
179 Kind k = getKind();
180 return k == And2 || k == Xor2 || k == Maj3 || k == Identity ||
181 k == Choice || k == Dot3 || k == OneHot3 || k == Mux3;
182 }
183
184 /// Check if this should always be a cut input (PI or constant).
185 bool isAlwaysCutInput() const {
186 Kind k = getKind();
187 return k == PrimaryInput || k == Constant;
188 }
189};
190
191/// Flat logic network representation for efficient cut enumeration.
192///
193/// This class provides a mockturtle-style flat representation of the
194/// combinational logic network. Each value in the MLIR IR is assigned a unique
195/// index, and gates are stored in a contiguous vector for cache efficiency.
196///
197/// The network supports:
198/// - O(1) lookup of gate information by index
199/// - Compact representation with inversion encoded in edges
200/// - Efficient simulation and truth table computation
201///
202/// Special reserved indices:
203/// - Index 0: Constant 0
204/// - Index 1: Constant 1
206public:
207 /// Special constant indices.
208 static constexpr uint32_t kConstant0 = 0;
209 static constexpr uint32_t kConstant1 = 1;
210
211 ArrayRef<LogicNetworkGate> getGates() const { return gates; }
212
214 // Reserve index 0 for constant 0 and index 1 for constant 1
215 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
216 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
217 // indexToValue needs placeholders for constants
218 indexToValue.push_back(Value()); // const0
219 indexToValue.push_back(Value()); // const1
220 }
221
  /// Get a Signal representing constant 0.
223 static Signal getConstant0() { return Signal(kConstant0, false); }
224
  /// Get a Signal representing constant 1 (constant 0 inverted).
226 static Signal getConstant1() { return Signal(kConstant0, true); }
227
228 /// Get or create an index for a value.
229 /// If the value doesn't have an index yet, assigns one and returns the index.
230 uint32_t getOrCreateIndex(Value value);
231
232 /// Get the raw index for a value. Asserts if value is not found.
233 /// Note: This returns only the index, not a Signal with inversion info.
234 /// Use hasIndex() to check existence first, or use getOrCreateIndex().
235 uint32_t getIndex(Value value) const;
236
237 /// Check if a value has been indexed.
238 bool hasIndex(Value value) const;
239
240 /// Get the value for a given raw index. Asserts if index is out of bounds.
241 /// Returns null Value for constant indices (0 and 1).
242 Value getValue(uint32_t index) const;
243
244 /// Fill values for the given raw indices.
245 void getValues(ArrayRef<uint32_t> indices,
246 SmallVectorImpl<Value> &values) const;
247
248 /// Get a Signal for a value.
249 /// Asserts if value not found - use hasIndex() first if unsure.
250 Signal getSignal(Value value, bool inverted) const {
251 return Signal(getIndex(value), inverted);
252 }
253
254 /// Get or create a Signal for a value.
255 Signal getOrCreateSignal(Value value, bool inverted) {
256 return Signal(getOrCreateIndex(value), inverted);
257 }
258
259 /// Get the value for a given Signal (extracts index from Signal).
260 Value getValue(Signal signal) const { return getValue(signal.getIndex()); }
261
262 /// Get the gate at a given index.
263 const LogicNetworkGate &getGate(uint32_t index) const { return gates[index]; }
264
265 /// Get mutable reference to gate at index.
266 LogicNetworkGate &getGate(uint32_t index) { return gates[index]; }
267
268 /// Get the total number of nodes in the network.
269 size_t size() const { return gates.size(); }
270
271 /// Add a primary input to the network.
272 uint32_t addPrimaryInput(Value value);
273
274 /// Add a gate with explicit result value and operand signals.
275 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result,
276 llvm::ArrayRef<Signal> operands = {});
277
278 /// Add a gate using op->getResult(0) as the result value.
279 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind,
280 llvm::ArrayRef<Signal> operands = {}) {
281 return addGate(op, kind, op->getResult(0), operands);
282 }
283
284 /// Build the logic network from a region/block in topological order.
285 /// Returns failure if the IR is not in a valid form.
286 LogicalResult buildFromBlock(Block *block);
287
288 /// Clear the network and reset to initial state.
289 void clear();
290
291private:
292 /// Map from MLIR Value to network index.
293 llvm::DenseMap<Value, uint32_t> valueToIndex;
294
295 /// Map from network index to MLIR Value.
296 llvm::SmallVector<Value> indexToValue;
297
298 /// Vector of all gates in the network.
299 llvm::SmallVector<LogicNetworkGate> gates;
300};
301
302/// Result of matching a cut against a pattern.
303///
304/// This structure contains the area and per-input delay information
305/// computed during pattern matching.
306///
307/// The delays can be stored in two ways:
308/// 1. As a reference to static/cached data (e.g., tech library delays)
309/// - Use setDelayRef() for zero-cost reference (no allocation)
310/// 2. As owned dynamic data (e.g., computed SOP delays)
311/// - Use setOwnedDelays() to transfer ownership
312///
314 /// Area cost of implementing this cut with the pattern.
315 double area = 0.0;
316
317 /// Default constructor.
318 MatchResult() = default;
319
320 /// Constructor with area and delays (by reference).
321 MatchResult(double area, ArrayRef<DelayType> delays)
322 : area(area), borrowedDelays(delays) {}
323
324 /// Set delays by reference (zero-cost for static/cached delays).
325 /// The caller must ensure the referenced data remains valid.
326 void setDelayRef(ArrayRef<DelayType> delays) { borrowedDelays = delays; }
327
328 /// Set delays by transferring ownership (for dynamically computed delays).
329 /// This moves the data into internal storage.
330 void setOwnedDelays(SmallVector<DelayType, 6> delays) {
331 ownedDelays.emplace(std::move(delays));
332 borrowedDelays = {};
333 }
334
335 /// Get all delays as an ArrayRef.
336 ArrayRef<DelayType> getDelays() const {
337 return ownedDelays.has_value() ? ArrayRef<DelayType>(*ownedDelays)
339 }
340
341private:
342 /// Borrowed delays (used when ownedDelays is empty).
343 /// Points to external data provided via setDelayRef().
344 ArrayRef<DelayType> borrowedDelays;
345
346 /// Owned delays (used when present).
347 /// Only allocated when setOwnedDelays() is called. When empty (std::nullopt),
348 /// moving this MatchResult avoids constructing/moving the SmallVector,
349 /// achieving zero-cost abstraction for the common case (borrowed delays).
350 std::optional<SmallVector<DelayType, 6>> ownedDelays;
351};
352
353/// Represents a cut that has been successfully matched to a rewriting pattern.
354///
355/// This class encapsulates the result of matching a cut against a rewriting
356/// pattern during optimization. It stores the matched pattern, the
357/// cut that was matched, and timing information needed for optimization.
359private:
360 const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
361 SmallVector<DelayType, 1>
362 arrivalTimes; ///< Arrival times of outputs from this pattern
363 double area = 0.0; ///< Area cost of this pattern
364
365public:
366 /// Default constructor creates an invalid matched pattern.
367 MatchedPattern() = default;
368
369 /// Constructor for a valid matched pattern.
371 SmallVector<DelayType, 1> arrivalTimes, double area)
372 : pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
373
374 /// Get the arrival time of signals through this pattern.
375 DelayType getArrivalTime(unsigned outputIndex) const;
376 ArrayRef<DelayType> getArrivalTimes() const;
377
378 /// Get the library pattern that was matched.
379 const CutRewritePattern *getPattern() const;
380
381 /// Get the area cost of using this pattern.
382 double getArea() const;
383};
384
385/// Represents a cut in the combinational logic network.
386///
387/// A cut is a subset of nodes in the combinational logic that forms a complete
388/// subgraph with a single output. It represents a portion of the circuit that
389/// can potentially be replaced with a single library gate or pattern.
390///
391/// The cut contains:
392/// - Input values: The boundary between the cut and the rest of the circuit
393/// - Operations: The logic operations within the cut boundary
394/// - Root operation: The output-driving operation of the cut
395///
396/// Cuts are used in combinational logic optimization to identify regions that
397/// can be optimized and replaced with more efficient implementations.
398class Cut {
399 /// Cached truth table for this cut.
400 /// Computed lazily when first accessed to avoid unnecessary computation.
401 mutable std::optional<BinaryTruthTable> truthTable;
402
403 /// Cached NPN canonical form for this cut.
404 /// Computed lazily from the truth table when first accessed.
405 mutable std::optional<NPNClass> npnClass;
406
407 std::optional<MatchedPattern> matchedPattern;
408
409 /// Root index in LogicNetwork (0 indicates no root for a trivial cut).
410 /// The root node produces the output of the cut.
411 uint32_t rootIndex = 0;
412
413 /// Signature bitset for fast cut size estimation.
414 /// Bit i is set if value with index i is in the cut's inputs.
415 /// This enables O(1) estimation of merged cut size using popcount.
416 uint64_t signature = 0;
417
418 /// Operand cuts used to create this cut (for lazy TT computation).
419 /// Stored to enable fast incremental truth table computation after
420 /// duplicate removal. Using raw pointers is safe since cuts are allocated
421 /// via bump allocator and live for the duration of enumeration.
422 llvm::SmallVector<const Cut *, 3> operandCuts;
423
424public:
425 Cut() = default;
426 Cut(uint32_t rootIndex, ArrayRef<uint32_t> inputs, uint64_t signature,
427 ArrayRef<const Cut *> operandCuts = {},
428 std::optional<BinaryTruthTable> truthTable = std::nullopt)
431 operandCuts(operandCuts.begin(), operandCuts.end()),
432 inputs(inputs.begin(), inputs.end()) {}
433
434 /// Create a trivial cut for a value.
435 static Cut getTrivialCut(uint32_t index);
436
437 /// External inputs to this cut (cut boundary).
438 /// Stored as LogicNetwork indices for efficient operations.
439 llvm::SmallVector<uint32_t, 6> inputs;
440
441 /// Check if this cut represents a trivial cut.
442 /// A trivial cut has no root operation and exactly one input.
443 bool isTrivialCut() const;
444
445 /// Get the root index in the LogicNetwork.
446 uint32_t getRootIndex() const { return rootIndex; }
447
448 /// Set the root index of this cut.
449 void setRootIndex(uint32_t idx) { rootIndex = idx; }
450
451 /// Get the signature of this cut.
452 uint64_t getSignature() const { return signature; }
453
454 /// Set the signature of this cut.
455 void setSignature(uint64_t sig) { signature = sig; }
456
457 /// Check if this cut dominates another (i.e., this cut's inputs are a subset
458 /// of the other's inputs). Uses signature pre-filtering for speed.
459 /// Both cuts must have sorted inputs.
460 bool dominates(const Cut &other) const;
461
462 /// Check if this cut dominates a set of sorted inputs with the given
463 /// signature.
464 bool dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const;
465
466 void dump(llvm::raw_ostream &os, const LogicNetwork &network) const;
467
468 /// Get the number of inputs to this cut.
469 unsigned getInputSize() const;
470
471 /// Get the number of outputs from root operation.
472 unsigned getOutputSize(const LogicNetwork &network) const;
473
474 /// Get the truth table for this cut.
475 /// The truth table represents the boolean function computed by this cut.
476 const std::optional<BinaryTruthTable> &getTruthTable() const {
477 return truthTable;
478 }
479
480 /// Compute truth table using fast incremental method from operand cuts.
481 /// Trivial cuts are handled directly; non-trivial cuts require that
482 /// operand cuts have already been set via setOperandCuts.
483 void computeTruthTableFromOperands(const LogicNetwork &network);
484
485 /// Set the truth table directly (used for incremental computation).
486 void setTruthTable(BinaryTruthTable tt) { truthTable.emplace(std::move(tt)); }
487
488 /// Set operand cuts for lazy truth table computation.
489 void setOperandCuts(ArrayRef<const Cut *> cuts) {
490 operandCuts.assign(cuts.begin(), cuts.end());
491 }
492
493 /// Get operand cuts (for fast TT computation).
494 ArrayRef<const Cut *> getOperandCuts() const { return operandCuts; }
495
496 /// Get the NPN canonical form for this cut.
497 /// This is used for efficient pattern matching against library components.
498 const NPNClass &getNPNClass() const;
499 const NPNClass &getNPNClass(const NPNTable *npnTable) const;
500
  /// Get the permuted inputs for this cut based on the given pattern NPN.
  /// The result is written to `permutedIndices` as indices into the `inputs`
  /// vector.
503 void
504 getPermutatedInputIndices(const NPNTable *npnTable,
505 const NPNClass &patternNPN,
506 SmallVectorImpl<unsigned> &permutedIndices) const;
507
508 /// Get arrival times for each input of this cut.
509 /// Returns failure if any input doesn't have a valid matched pattern.
510 LogicalResult getInputArrivalTimes(CutEnumerator &enumerator,
511 SmallVectorImpl<DelayType> &results) const;
512
513 /// Matched pattern for this cut.
517
518 /// Get the matched pattern for this cut.
519 const std::optional<MatchedPattern> &getMatchedPattern() const {
520 return matchedPattern;
521 }
522};
523
524/// Manages a collection of cuts for a single logic node using priority cuts
525/// algorithm.
526///
527/// Each node in the combinational logic network can have multiple cuts
528/// representing different ways to group it with surrounding logic. The CutSet
529/// manages these cuts and selects the best one based on the optimization
530/// strategy (area or timing).
531///
532/// The priority cuts algorithm maintains a bounded set of the most promising
533/// cuts to avoid exponential explosion while ensuring good optimization
534/// results.
535class CutSet {
536private:
537 llvm::SmallVector<Cut *, 12> cuts; ///< Collection of cuts for this node
538 Cut *bestCut = nullptr;
539 bool isFrozen = false; ///< Whether cut set is finalized
540
541public:
542 /// Check if this cut set has a valid matched pattern.
543 bool isMatched() const { return bestCut; }
544
545 /// Get the cut associated with the best matched pattern.
546 Cut *getBestMatchedCut() const;
547
548 /// Finalize the cut set by removing duplicates and selecting the best
549 /// pattern.
550 void finalize(
551 const CutRewriterOptions &options,
552 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
553 const LogicNetwork &logicNetwork);
554
555 /// Get the number of cuts in this set.
556 unsigned size() const;
557
558 /// Add a new cut to this set using bump allocator.
559 /// NOTE: The cut set must not be frozen
560 void addCut(Cut *cut);
561
562 /// Get read-only access to all cuts in this set.
563 ArrayRef<Cut *> getCuts() const;
564};
565
566/// Configuration options for the cut-based rewriting algorithm.
567///
568/// These options control various aspects of the rewriting process including
569/// optimization strategy, resource limits, and algorithmic parameters.
571 /// Optimization strategy (area vs. timing).
573
574 /// Maximum number of inputs allowed for any cut.
575 /// Larger cuts provide more optimization opportunities but increase
576 /// computational complexity exponentially.
578
579 /// Maximum number of cuts to maintain per logic node.
580 /// The priority cuts algorithm keeps only the most promising cuts
581 /// to prevent exponential explosion.
583
  /// Allow root operations that have no matching pattern. When false (the
  /// default), such a root operation causes a failure.
585 bool allowNoMatch = false;
586
  /// Attach arrival times to rewritten operations (for debugging).
588 bool attachDebugTiming = false;
589
590 /// Run priority cuts enumeration and dump the cut sets.
591 bool testPriorityCuts = false;
592
593 /// Optional lookup table used to accelerate 4-input NPN canonicalization.
594 const NPNTable *npnTable = nullptr;
595};
596
597//===----------------------------------------------------------------------===//
598// Cut Enumeration Engine
599//===----------------------------------------------------------------------===//
600
602 uint64_t numCutsCreated = 0;
603 uint64_t numCutSetsCreated = 0;
604 uint64_t numCutsRewritten = 0;
605};
606
607template <typename T>
609public:
612
613 template <typename... Args>
614 T *create(Args &&...args) {
616 return new (allocator.Allocate()) T(std::forward<Args>(args)...);
617 }
618
619 void DestroyAll() { allocator.DestroyAll(); }
620
621private:
622 llvm::SpecificBumpPtrAllocator<T> allocator;
624};
625/// Cut enumeration engine for combinational logic networks.
626///
627/// The CutEnumerator is responsible for generating cuts for each node in a
628/// combinational logic network. It uses a priority cuts algorithm to maintain a
629/// bounded set of promising cuts while avoiding exponential explosion.
630///
631/// The enumeration process works by:
632/// 1. Visiting nodes in topological order
633/// 2. For each node, combining cuts from its inputs
634/// 3. Matching generated cuts against available patterns
635/// 4. Maintaining only the most promising cuts per node
637public:
638 /// Constructor for cut enumerator.
640
641 /// Enumerate cuts for all nodes in the given module.
642 ///
643 /// This is the main entry point that orchestrates the cut enumeration
644 /// process. It visits all operations in the module and generates cuts
645 /// for combinational logic operations.
646 LogicalResult enumerateCuts(
647 Operation *topOp,
648 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
649 [](const Cut &) { return std::nullopt; });
650
651 /// Create a new cut set for an index.
652 /// The index must not already have a cut set.
653 CutSet *createNewCutSet(uint32_t index);
654
655 /// Get the cut set for a specific index.
656 /// If not found, it means no cuts have been generated for this value yet.
657 /// In that case return a trivial cut set.
658 const CutSet *getCutSet(uint32_t index);
659
660 /// Clear all cut sets and reset the enumerator.
661 void clear();
662
663 /// Get the cut rewriter options used for this enumeration.
664 const CutRewriterOptions &getOptions() const { return options; }
665
666 /// Record that one cut was successfully rewritten.
668
669 void dump() const;
670
671 /// Get cut sets (indexed by LogicNetwork index).
672 const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
673 return cutSets;
674 }
675
676 /// Get the processing order.
677 ArrayRef<uint32_t> getProcessingOrder() const { return processingOrder; }
678
679private:
680 /// Visit a combinational logic operation and generate cuts.
681 /// This handles the core cut enumeration logic for operations
682 /// like AND, OR, XOR, etc.
683 LogicalResult visitLogicOp(uint32_t nodeIndex);
684
685 /// Maps indices to their associated cut sets.
686 /// CutSets are allocated from the bump allocator.
687 llvm::DenseMap<uint32_t, CutSet *> cutSets;
688
689 /// Typed bump allocators for fast allocation with destructors.
692
693 /// Indices in processing order.
694 llvm::SmallVector<uint32_t> processingOrder;
695
696 /// Configuration options for cut enumeration.
698
699 /// Function to match cuts against available patterns.
700 /// Set during enumeration and used when finalizing cut sets.
701 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
702
703 /// Flat logic network representation used during enumeration/rewrite.
705
706 /// Statistics for cut enumeration (number of cuts allocated, etc.).
708
709public:
710 /// Get the logic network (read-only).
711 const LogicNetwork &getLogicNetwork() const { return logicNetwork; }
712
713 /// Get the logic network (mutable).
715
716 /// Get enumeration statistics.
717 const CutEnumeratorStats &getStats() const { return stats; }
718};
719
720/// Base class for cut rewriting patterns used in combinational logic
721/// optimization.
722///
723/// A CutRewritePattern represents a library component or optimization pattern
724/// that can replace cuts in the combinational logic network. Each pattern
725/// defines:
726/// - How to recognize matching cuts and compute area/delay metrics
727/// - How to transform/replace the matched cuts
728///
729/// Patterns can use truth table matching for efficient recognition or
730/// implement custom matching logic for more complex cases.
732 CutRewritePattern(mlir::MLIRContext *context) : context(context) {}
733 /// Virtual destructor for base class.
734 virtual ~CutRewritePattern() = default;
735
736 /// Check if a cut matches this pattern and compute area/delay metrics.
737 ///
738 /// This method is called to determine if a cut can be replaced by this
739 /// pattern. If the cut matches, it should return a MatchResult containing
740 /// the area and per-input delays for this specific cut.
741 ///
742 /// If useTruthTableMatcher() returns true, this method is only
743 /// called for cuts with matching truth tables.
744 virtual std::optional<MatchResult> match(CutEnumerator &enumerator,
745 const Cut &cut) const = 0;
746
747 /// Specify truth tables that this pattern can match.
748 ///
749 /// If this method returns true, the pattern matcher will use truth table
750 /// comparison for efficient pre-filtering. Only cuts with matching truth
751 /// tables will be passed to the match() method. If it returns false, the
752 /// pattern will be checked against all cuts regardless of their truth tables.
753 /// This is useful for patterns that match regardless of their truth tables,
754 /// such as LUT-based patterns.
755 virtual bool
756 useTruthTableMatcher(SmallVectorImpl<NPNClass> &matchingNPNClasses) const;
757
758 /// Return a new operation that replaces the matched cut.
759 ///
760 /// Unlike MLIR's RewritePattern framework which allows arbitrary in-place
761 /// modifications, this method creates a new operation to replace the matched
762 /// cut rather than modifying existing operations. This constraint exists
763 /// because the cut enumerator maintains references to operations throughout
764 /// the circuit, making it safe to only replace the root operation of each
765 /// cut while preserving all other operations unchanged.
766 virtual FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
767 CutEnumerator &enumerator,
768 const Cut &cut) const = 0;
769
770 /// Get the number of outputs this pattern produces.
771 virtual unsigned getNumOutputs() const = 0;
772
773 /// Get the name of this pattern. Used for debugging.
774 virtual StringRef getPatternName() const { return "<unnamed>"; }
775
  /// Get the location of this pattern (optional). Defaults to an unknown
  /// location.
777 virtual LocationAttr getLoc() const { return mlir::UnknownLoc::get(context); }
778
779 mlir::MLIRContext *getContext() const { return context; }
780
781private:
782 mlir::MLIRContext *context;
783};
784
785/// Manages a collection of rewriting patterns for combinational logic
786/// optimization.
787///
788/// This class organizes and provides efficient access to rewriting patterns
789/// used during cut-based optimization. It maintains:
790/// - A collection of all available patterns
791/// - Fast lookup tables for truth table-based matching
792/// - Separation of truth table vs. custom matching patterns
793///
794/// The pattern set is used by the CutRewriter to find suitable replacements
795/// for cuts in the combinational logic network.
797public:
798 /// Constructor that takes ownership of the provided patterns.
799 ///
800 /// During construction, patterns are analyzed and organized for efficient
801 /// lookup. Truth table matchers are indexed by their NPN canonical forms.
803 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns);
804
805private:
806 /// Owned collection of all rewriting patterns.
807 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns;
808
809 /// Fast lookup table mapping NPN canonical forms to matching patterns.
810 /// Each entry maps a truth table and input size to patterns that can handle
811 /// it.
812 DenseMap<std::pair<APInt, unsigned>,
813 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
815
816 /// Patterns that use custom matching logic instead of NPN lookup.
817 /// These patterns are checked against every cut.
818 SmallVector<const CutRewritePattern *, 4> nonNPNPatterns;
819
820 /// CutRewriter needs access to internal data structures for pattern matching.
821 friend class CutRewriter;
822};
823
824/// Main cut-based rewriting algorithm for combinational logic optimization.
825///
826/// The CutRewriter implements a cut-based rewriting algorithm that:
827/// 1. Enumerates cuts in the combinational logic network using a priority cuts
828/// algorithm
829/// 2. Matches cuts against available rewriting patterns
830/// 3. Selects optimal patterns based on area or timing objectives
831/// 4. Rewrites the circuit using the selected patterns
832///
833/// The algorithm processes the network in topological order, building up cut
834/// sets for each node and selecting the best implementation based on the
835/// specified optimization strategy.
836///
837/// Usage example:
838/// ```cpp
839/// CutRewriterOptions options;
840/// options.strategy = OptimizationStrategy::Area;
841/// options.maxCutInputSize = 4;
842/// options.maxCutSizePerRoot = 8;
843///
844/// CutRewritePatternSet patterns(std::move(optimizationPatterns));
845/// CutRewriter rewriter(module, options, patterns);
846///
/// if (failed(rewriter.run(module))) {
848/// // Handle rewriting failure
849/// }
850/// ```
852public:
853 /// Constructor for the cut rewriter.
856
857 /// Execute the complete cut-based rewriting algorithm.
858 ///
859 /// This method orchestrates the entire rewriting process:
860 /// 1. Enumerate cuts for all nodes in the combinational logic
861 /// 2. Match cuts against available patterns
862 /// 3. Select optimal patterns based on strategy
863 /// 4. Rewrite the circuit with selected patterns
864 LogicalResult run(Operation *topOp);
865
867 return cutEnumerator.getStats();
868 }
869
870private:
871 /// Enumerate cuts for all nodes in the given module.
872 /// Note: This preserves module boundaries and does not perform
873 /// rewriting across the hierarchy.
874 LogicalResult enumerateCuts(Operation *topOp);
875
876 /// Find patterns that match a cut's truth table.
877 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
878 getMatchingPatternsFromTruthTable(const Cut &cut) const;
879
880 /// Match a cut against available patterns and compute arrival time.
881 std::optional<MatchedPattern> patternMatchCut(const Cut &cut);
882
883 /// Perform the actual circuit rewriting using selected patterns.
884 LogicalResult runBottomUpRewrite(Operation *topOp);
885
886 /// Configuration options
888
889 /// Available rewriting patterns
891
893};
894
895} // namespace synth
896} // namespace circt
897
898#endif // CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumeratorStats stats
Statistics for cut enumeration (number of cuts allocated, etc.).
void noteCutRewritten()
Record that one cut was successfully rewritten.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
const CutEnumeratorStats & getStats() const
Get enumeration statistics.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutEnumeratorStats & getStats() const
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using the bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
Cut(uint32_t rootIndex, ArrayRef< uint32_t > inputs, uint64_t signature, ArrayRef< const Cut * > operandCuts={}, std::optional< BinaryTruthTable > truthTable=std::nullopt)
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from the root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::SpecificBumpPtrAllocator< T > allocator
TrackedSpecificBumpPtrAllocator(uint64_t &allocationCount)
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when false, rewriting fails if any root operation has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations for debugging.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
Operation * op
Operation pointer and gate kind. Constants have a null operation pointer.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ OneHot3
OneHot gate (3-input, synth.onehot)
@ Identity
Identity gate (used for 1-input inverter)
@ Mux3
Ordered MUX gate (3-input, synth.mux_inv)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
Definition CutRewriter.h:75
uint32_t getIndex() const
Get the node index (without the inversion bit).
Definition CutRewriter.h:84
Signal operator!() const
Create an inverted version of this edge.
Definition CutRewriter.h:95
bool operator<(const Signal &other) const
Definition CutRewriter.h:99
bool operator!=(const Signal &other) const
Definition CutRewriter.h:98
Signal flipInversion() const
Definition CutRewriter.h:92
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
Definition CutRewriter.h:90
bool isInverted() const
Check if this edge is inverted.
Definition CutRewriter.h:87
bool operator==(const Signal &other) const
Definition CutRewriter.h:97
Signal(uint32_t index, bool inverted)
Definition CutRewriter.h:79
Signal(uint32_t raw)
Definition CutRewriter.h:81