CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.h
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines a general cut-based rewriting framework for
10// combinational logic optimization. The framework uses NPN-equivalence matching
11// with area and delay metrics to rewrite cuts (subgraphs) in combinational
12// circuits with optimal patterns.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
18
20#include "circt/Support/LLVM.h"
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
31#include <memory>
32#include <optional>
33#include <utility>
34
35namespace circt {
36namespace synth {
// Type for representing delays in the circuit. It is the user's responsibility
// to use consistent units, i.e., all delays should be in the same unit
// (usually femtoseconds, but not limited to that).
40using DelayType = int64_t;
41
42/// Maximum number of inputs supported for truth table generation.
43/// This limit prevents excessive memory usage as truth table size grows
44/// exponentially with the number of inputs (2^n entries).
45static constexpr unsigned maxTruthTableInputs = 16;
46
47// This is a helper function to sort operations topologically in a logic
48// network. This is necessary for cut rewriting to ensure that operations are
49// processed in the correct order, respecting dependencies.
50LogicalResult topologicallySortLogicNetwork(mlir::Operation *op);
51
52// Get the truth table for a specific operation within a block.
53// Block must be a SSACFG or topologically sorted.
54FailureOr<BinaryTruthTable> getTruthTable(ValueRange values, Block *block);
55
56//===----------------------------------------------------------------------===//
57// Cut Data Structures
58//===----------------------------------------------------------------------===//
59
60// Forward declarations
62class CutRewriter;
63class CutEnumerator;
66class LogicNetwork;
67
68//===----------------------------------------------------------------------===//
69// Logic Network Data Structures (Flat IR for efficient cut enumeration)
70//===----------------------------------------------------------------------===//
71
/// Edge representation in the logic network.
/// Similar to mockturtle's signal, this encodes both a node index and an
/// inversion flag in a single 32-bit value: the upper 31 bits hold the node
/// index and the LSB indicates whether the signal is inverted.
struct Signal {
  /// Packed encoding: (index << 1) | inverted.
  uint32_t data = 0;

  /// Default-construct the non-inverted edge to node index 0.
  Signal() = default;

  /// Build an edge to node `index`, optionally inverted.
  /// Only 31 bits are available for the index. A larger value would lose its
  /// top bit in the shift below and silently alias a different node, so it is
  /// rejected in assert-enabled builds.
  Signal(uint32_t index, bool inverted)
      : data((index << 1) | (inverted ? 1u : 0u)) {
    assert((index >> 31) == 0 && "Signal index does not fit in 31 bits");
  }

  /// Wrap an already-encoded raw value (index << 1 | inverted).
  explicit Signal(uint32_t raw) : data(raw) {}

  /// Get the node index (without the inversion bit).
  uint32_t getIndex() const { return data >> 1; }

  /// Check if this edge is inverted.
  bool isInverted() const { return (data & 1u) != 0; }

  /// Get the raw data (index << 1 | inverted).
  uint32_t getRaw() const { return data; }

  /// Return the same edge with the inversion flag toggled. Equivalent to
  /// operator!; both simply flip the LSB of the encoding.
  Signal flipInversion() const { return Signal(data ^ 1u); }

  /// Create an inverted version of this edge.
  Signal operator!() const { return Signal(data ^ 1u); }

  /// Comparisons operate on the raw encoding, so an edge and its inverted
  /// counterpart compare unequal and are ordered next to each other.
  bool operator==(const Signal &other) const { return data == other.data; }
  bool operator!=(const Signal &other) const { return data != other.data; }
  bool operator<(const Signal &other) const { return data < other.data; }
};
101
102/// Represents a single gate/node in the flat logic network.
103/// This structure is designed to be cache-friendly and supports up to 3 inputs
104/// (sufficient for AND, XOR, MAJ gates). For nodes with fewer inputs, unused
105/// edges have index 0 (constant 0 node).
106///
107/// Special indices:
108/// - Index 0: Constant 0
109/// - Index 1: Constant 1
110///
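/// It uses 8 bytes for the packed operation pointer + kind and 12 bytes for
/// the three edges, i.e. 20 bytes of payload per gate (typically 24 bytes
/// once padded to pointer alignment on 64-bit targets).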
114 /// Kind of logic gate.
115 enum Kind : uint8_t {
116 Constant = 0, ///< Constant 0/1 node (index 0 = const0, index 1 = const1)
117 PrimaryInput = 1, ///< Primary input to the network
118 And2 = 2, ///< AND gate (2-input, aig::AndInverterOp)
119 Xor2 = 3, ///< XOR gate (2-input)
120 Maj3 = 4, ///< Reserved 3-input gate kind
121 Identity = 5, ///< Identity gate (used for 1-input inverter)
122 Choice = 6, ///< Choice node (synth.choice)
123 Dot3 = 7 ///< Ordered DOT gate (3-input, synth.dot)
124 };
125
126 /// Operation pointer and kind packed together.
127 /// The kind is stored in the low bits of the pointer.
128 llvm::PointerIntPair<Operation *, 3, Kind> opAndKind;
129
130 /// Fanin edges (up to 3 inputs). For AND gates, only edges[0] and edges[1]
131 /// are used. For PrimaryInput/Constant and Choice, none are used. The
132 /// inversion bit is encoded in each edge.
134
136 LogicNetworkGate(Operation *op, Kind kind,
137 llvm::ArrayRef<Signal> operands = {})
138 : opAndKind(op, kind), edges{} {
139 assert(operands.size() <= 3 && "Too many operands for LogicNetworkGate");
140 for (size_t i = 0; i < operands.size(); ++i)
141 edges[i] = operands[i];
142 }
143
144 /// Get the kind of this gate.
145 Kind getKind() const { return opAndKind.getInt(); }
146
147 /// Get the operation pointer (nullptr for constants).
148 Operation *getOperation() const { return opAndKind.getPointer(); }
149
150 /// Get the number of fanin edges based on kind.
151 unsigned getNumFanins() const {
152 switch (getKind()) {
153 case Constant:
154 case PrimaryInput:
155 return 0;
156 case And2:
157 case Xor2:
158 return 2;
159 case Maj3:
160 case Dot3:
161 return 3;
162 case Identity:
163 return 1;
164 case Choice:
165 return 0;
166 }
167 llvm_unreachable("Unknown gate kind");
168 }
169
170 /// Check if this is a logic gate that can be part of a cut.
171 bool isLogicGate() const {
172 Kind k = getKind();
173 return k == And2 || k == Xor2 || k == Maj3 || k == Identity ||
174 k == Choice || k == Dot3;
175 }
176
177 /// Check if this should always be a cut input (PI or constant).
178 bool isAlwaysCutInput() const {
179 Kind k = getKind();
180 return k == PrimaryInput || k == Constant;
181 }
182};
183
184/// Flat logic network representation for efficient cut enumeration.
185///
186/// This class provides a mockturtle-style flat representation of the
187/// combinational logic network. Each value in the MLIR IR is assigned a unique
188/// index, and gates are stored in a contiguous vector for cache efficiency.
189///
190/// The network supports:
191/// - O(1) lookup of gate information by index
192/// - Compact representation with inversion encoded in edges
193/// - Efficient simulation and truth table computation
194///
195/// Special reserved indices:
196/// - Index 0: Constant 0
197/// - Index 1: Constant 1
199public:
200 /// Special constant indices.
201 static constexpr uint32_t kConstant0 = 0;
202 static constexpr uint32_t kConstant1 = 1;
203
204 ArrayRef<LogicNetworkGate> getGates() const { return gates; }
205
207 // Reserve index 0 for constant 0 and index 1 for constant 1
208 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
209 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
210 // indexToValue needs placeholders for constants
211 indexToValue.push_back(Value()); // const0
212 indexToValue.push_back(Value()); // const1
213 }
214
/// Get a Signal representing constant 0.
216 static Signal getConstant0() { return Signal(kConstant0, false); }
217
/// Get a Signal representing constant 1 (constant 0 inverted).
219 static Signal getConstant1() { return Signal(kConstant0, true); }
220
221 /// Get or create an index for a value.
222 /// If the value doesn't have an index yet, assigns one and returns the index.
223 uint32_t getOrCreateIndex(Value value);
224
225 /// Get the raw index for a value. Asserts if value is not found.
226 /// Note: This returns only the index, not a Signal with inversion info.
227 /// Use hasIndex() to check existence first, or use getOrCreateIndex().
228 uint32_t getIndex(Value value) const;
229
230 /// Check if a value has been indexed.
231 bool hasIndex(Value value) const;
232
233 /// Get the value for a given raw index. Asserts if index is out of bounds.
234 /// Returns null Value for constant indices (0 and 1).
235 Value getValue(uint32_t index) const;
236
237 /// Fill values for the given raw indices.
238 void getValues(ArrayRef<uint32_t> indices,
239 SmallVectorImpl<Value> &values) const;
240
241 /// Get a Signal for a value.
242 /// Asserts if value not found - use hasIndex() first if unsure.
243 Signal getSignal(Value value, bool inverted) const {
244 return Signal(getIndex(value), inverted);
245 }
246
247 /// Get or create a Signal for a value.
248 Signal getOrCreateSignal(Value value, bool inverted) {
249 return Signal(getOrCreateIndex(value), inverted);
250 }
251
252 /// Get the value for a given Signal (extracts index from Signal).
253 Value getValue(Signal signal) const { return getValue(signal.getIndex()); }
254
255 /// Get the gate at a given index.
256 const LogicNetworkGate &getGate(uint32_t index) const { return gates[index]; }
257
258 /// Get mutable reference to gate at index.
259 LogicNetworkGate &getGate(uint32_t index) { return gates[index]; }
260
261 /// Get the total number of nodes in the network.
262 size_t size() const { return gates.size(); }
263
264 /// Add a primary input to the network.
265 uint32_t addPrimaryInput(Value value);
266
267 /// Add a gate with explicit result value and operand signals.
268 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result,
269 llvm::ArrayRef<Signal> operands = {});
270
271 /// Add a gate using op->getResult(0) as the result value.
272 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind,
273 llvm::ArrayRef<Signal> operands = {}) {
274 return addGate(op, kind, op->getResult(0), operands);
275 }
276
277 /// Build the logic network from a region/block in topological order.
278 /// Returns failure if the IR is not in a valid form.
279 LogicalResult buildFromBlock(Block *block);
280
281 /// Clear the network and reset to initial state.
282 void clear();
283
284private:
285 /// Map from MLIR Value to network index.
286 llvm::DenseMap<Value, uint32_t> valueToIndex;
287
288 /// Map from network index to MLIR Value.
289 llvm::SmallVector<Value> indexToValue;
290
291 /// Vector of all gates in the network.
292 llvm::SmallVector<LogicNetworkGate> gates;
293};
294
295/// Result of matching a cut against a pattern.
296///
297/// This structure contains the area and per-input delay information
298/// computed during pattern matching.
299///
300/// The delays can be stored in two ways:
301/// 1. As a reference to static/cached data (e.g., tech library delays)
302/// - Use setDelayRef() for zero-cost reference (no allocation)
303/// 2. As owned dynamic data (e.g., computed SOP delays)
304/// - Use setOwnedDelays() to transfer ownership
305///
307 /// Area cost of implementing this cut with the pattern.
308 double area = 0.0;
309
310 /// Default constructor.
311 MatchResult() = default;
312
313 /// Constructor with area and delays (by reference).
314 MatchResult(double area, ArrayRef<DelayType> delays)
315 : area(area), borrowedDelays(delays) {}
316
317 /// Set delays by reference (zero-cost for static/cached delays).
318 /// The caller must ensure the referenced data remains valid.
319 void setDelayRef(ArrayRef<DelayType> delays) { borrowedDelays = delays; }
320
321 /// Set delays by transferring ownership (for dynamically computed delays).
322 /// This moves the data into internal storage.
323 void setOwnedDelays(SmallVector<DelayType, 6> delays) {
324 ownedDelays.emplace(std::move(delays));
325 borrowedDelays = {};
326 }
327
328 /// Get all delays as an ArrayRef.
329 ArrayRef<DelayType> getDelays() const {
330 return ownedDelays.has_value() ? ArrayRef<DelayType>(*ownedDelays)
332 }
333
334private:
335 /// Borrowed delays (used when ownedDelays is empty).
336 /// Points to external data provided via setDelayRef().
337 ArrayRef<DelayType> borrowedDelays;
338
339 /// Owned delays (used when present).
340 /// Only allocated when setOwnedDelays() is called. When empty (std::nullopt),
341 /// moving this MatchResult avoids constructing/moving the SmallVector,
342 /// achieving zero-cost abstraction for the common case (borrowed delays).
343 std::optional<SmallVector<DelayType, 6>> ownedDelays;
344};
345
346/// Represents a cut that has been successfully matched to a rewriting pattern.
347///
348/// This class encapsulates the result of matching a cut against a rewriting
349/// pattern during optimization. It stores the matched pattern, the
350/// cut that was matched, and timing information needed for optimization.
352private:
353 const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
354 SmallVector<DelayType, 1>
355 arrivalTimes; ///< Arrival times of outputs from this pattern
356 double area = 0.0; ///< Area cost of this pattern
357
358public:
359 /// Default constructor creates an invalid matched pattern.
360 MatchedPattern() = default;
361
362 /// Constructor for a valid matched pattern.
364 SmallVector<DelayType, 1> arrivalTimes, double area)
365 : pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
366
367 /// Get the arrival time of signals through this pattern.
368 DelayType getArrivalTime(unsigned outputIndex) const;
369 ArrayRef<DelayType> getArrivalTimes() const;
370
371 /// Get the library pattern that was matched.
372 const CutRewritePattern *getPattern() const;
373
374 /// Get the area cost of using this pattern.
375 double getArea() const;
376};
377
378/// Represents a cut in the combinational logic network.
379///
380/// A cut is a subset of nodes in the combinational logic that forms a complete
381/// subgraph with a single output. It represents a portion of the circuit that
382/// can potentially be replaced with a single library gate or pattern.
383///
384/// The cut contains:
385/// - Input values: The boundary between the cut and the rest of the circuit
386/// - Operations: The logic operations within the cut boundary
387/// - Root operation: The output-driving operation of the cut
388///
389/// Cuts are used in combinational logic optimization to identify regions that
390/// can be optimized and replaced with more efficient implementations.
391class Cut {
392 /// Cached truth table for this cut.
393 /// Computed lazily when first accessed to avoid unnecessary computation.
394 mutable std::optional<BinaryTruthTable> truthTable;
395
396 /// Cached NPN canonical form for this cut.
397 /// Computed lazily from the truth table when first accessed.
398 mutable std::optional<NPNClass> npnClass;
399
400 std::optional<MatchedPattern> matchedPattern;
401
402 /// Root index in LogicNetwork (0 indicates no root for a trivial cut).
403 /// The root node produces the output of the cut.
404 uint32_t rootIndex = 0;
405
406 /// Signature bitset for fast cut size estimation.
407 /// Bit i is set if value with index i is in the cut's inputs.
408 /// This enables O(1) estimation of merged cut size using popcount.
409 uint64_t signature = 0;
410
411 /// Operand cuts used to create this cut (for lazy TT computation).
412 /// Stored to enable fast incremental truth table computation after
413 /// duplicate removal. Using raw pointers is safe since cuts are allocated
414 /// via bump allocator and live for the duration of enumeration.
415 llvm::SmallVector<const Cut *, 3> operandCuts;
416
417public:
418 Cut() = default;
419 Cut(uint32_t rootIndex, ArrayRef<uint32_t> inputs, uint64_t signature,
420 ArrayRef<const Cut *> operandCuts = {},
421 std::optional<BinaryTruthTable> truthTable = std::nullopt)
424 operandCuts(operandCuts.begin(), operandCuts.end()),
425 inputs(inputs.begin(), inputs.end()) {}
426
427 /// Create a trivial cut for a value.
428 static Cut getTrivialCut(uint32_t index);
429
430 /// External inputs to this cut (cut boundary).
431 /// Stored as LogicNetwork indices for efficient operations.
432 llvm::SmallVector<uint32_t, 6> inputs;
433
434 /// Check if this cut represents a trivial cut.
435 /// A trivial cut has no root operation and exactly one input.
436 bool isTrivialCut() const;
437
438 /// Get the root index in the LogicNetwork.
439 uint32_t getRootIndex() const { return rootIndex; }
440
441 /// Set the root index of this cut.
442 void setRootIndex(uint32_t idx) { rootIndex = idx; }
443
444 /// Get the signature of this cut.
445 uint64_t getSignature() const { return signature; }
446
447 /// Set the signature of this cut.
448 void setSignature(uint64_t sig) { signature = sig; }
449
450 /// Check if this cut dominates another (i.e., this cut's inputs are a subset
451 /// of the other's inputs). Uses signature pre-filtering for speed.
452 /// Both cuts must have sorted inputs.
453 bool dominates(const Cut &other) const;
454
455 /// Check if this cut dominates a set of sorted inputs with the given
456 /// signature.
457 bool dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const;
458
459 void dump(llvm::raw_ostream &os, const LogicNetwork &network) const;
460
461 /// Get the number of inputs to this cut.
462 unsigned getInputSize() const;
463
464 /// Get the number of outputs from root operation.
465 unsigned getOutputSize(const LogicNetwork &network) const;
466
467 /// Get the truth table for this cut.
468 /// The truth table represents the boolean function computed by this cut.
469 const std::optional<BinaryTruthTable> &getTruthTable() const {
470 return truthTable;
471 }
472
473 /// Compute truth table using fast incremental method from operand cuts.
474 /// Trivial cuts are handled directly; non-trivial cuts require that
475 /// operand cuts have already been set via setOperandCuts.
476 void computeTruthTableFromOperands(const LogicNetwork &network);
477
478 /// Set the truth table directly (used for incremental computation).
479 void setTruthTable(BinaryTruthTable tt) { truthTable.emplace(std::move(tt)); }
480
481 /// Set operand cuts for lazy truth table computation.
482 void setOperandCuts(ArrayRef<const Cut *> cuts) {
483 operandCuts.assign(cuts.begin(), cuts.end());
484 }
485
486 /// Get operand cuts (for fast TT computation).
487 ArrayRef<const Cut *> getOperandCuts() const { return operandCuts; }
488
489 /// Get the NPN canonical form for this cut.
490 /// This is used for efficient pattern matching against library components.
491 const NPNClass &getNPNClass() const;
492 const NPNClass &getNPNClass(const NPNTable *npnTable) const;
493
/// Get the permuted inputs for this cut based on the given pattern NPN.
495 /// Returns indices into the inputs vector.
496 void
497 getPermutatedInputIndices(const NPNTable *npnTable,
498 const NPNClass &patternNPN,
499 SmallVectorImpl<unsigned> &permutedIndices) const;
500
501 /// Get arrival times for each input of this cut.
502 /// Returns failure if any input doesn't have a valid matched pattern.
503 LogicalResult getInputArrivalTimes(CutEnumerator &enumerator,
504 SmallVectorImpl<DelayType> &results) const;
505
506 /// Matched pattern for this cut.
510
511 /// Get the matched pattern for this cut.
512 const std::optional<MatchedPattern> &getMatchedPattern() const {
513 return matchedPattern;
514 }
515};
516
517/// Manages a collection of cuts for a single logic node using priority cuts
518/// algorithm.
519///
520/// Each node in the combinational logic network can have multiple cuts
521/// representing different ways to group it with surrounding logic. The CutSet
522/// manages these cuts and selects the best one based on the optimization
523/// strategy (area or timing).
524///
525/// The priority cuts algorithm maintains a bounded set of the most promising
526/// cuts to avoid exponential explosion while ensuring good optimization
527/// results.
528class CutSet {
529private:
530 llvm::SmallVector<Cut *, 12> cuts; ///< Collection of cuts for this node
531 Cut *bestCut = nullptr;
532 bool isFrozen = false; ///< Whether cut set is finalized
533
534public:
535 /// Check if this cut set has a valid matched pattern.
536 bool isMatched() const { return bestCut; }
537
538 /// Get the cut associated with the best matched pattern.
539 Cut *getBestMatchedCut() const;
540
541 /// Finalize the cut set by removing duplicates and selecting the best
542 /// pattern.
543 void finalize(
544 const CutRewriterOptions &options,
545 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
546 const LogicNetwork &logicNetwork);
547
548 /// Get the number of cuts in this set.
549 unsigned size() const;
550
551 /// Add a new cut to this set using bump allocator.
552 /// NOTE: The cut set must not be frozen
553 void addCut(Cut *cut);
554
555 /// Get read-only access to all cuts in this set.
556 ArrayRef<Cut *> getCuts() const;
557};
558
559/// Configuration options for the cut-based rewriting algorithm.
560///
561/// These options control various aspects of the rewriting process including
562/// optimization strategy, resource limits, and algorithmic parameters.
564 /// Optimization strategy (area vs. timing).
566
567 /// Maximum number of inputs allowed for any cut.
568 /// Larger cuts provide more optimization opportunities but increase
569 /// computational complexity exponentially.
571
572 /// Maximum number of cuts to maintain per logic node.
573 /// The priority cuts algorithm keeps only the most promising cuts
574 /// to prevent exponential explosion.
576
/// Allow root operations that have no matching pattern. When false (the
/// default), fail if there is a root operation that has no matching pattern.
578 bool allowNoMatch = false;
579
580 /// Put arrival times to rewritten operations.
581 bool attachDebugTiming = false;
582
583 /// Run priority cuts enumeration and dump the cut sets.
584 bool testPriorityCuts = false;
585
586 /// Optional lookup table used to accelerate 4-input NPN canonicalization.
587 const NPNTable *npnTable = nullptr;
588};
589
590//===----------------------------------------------------------------------===//
591// Cut Enumeration Engine
592//===----------------------------------------------------------------------===//
593
595 uint64_t numCutsCreated = 0;
596 uint64_t numCutSetsCreated = 0;
597 uint64_t numCutsRewritten = 0;
598};
599
600template <typename T>
602public:
605
606 template <typename... Args>
607 T *create(Args &&...args) {
609 return new (allocator.Allocate()) T(std::forward<Args>(args)...);
610 }
611
612 void DestroyAll() { allocator.DestroyAll(); }
613
614private:
615 llvm::SpecificBumpPtrAllocator<T> allocator;
617};
618/// Cut enumeration engine for combinational logic networks.
619///
620/// The CutEnumerator is responsible for generating cuts for each node in a
621/// combinational logic network. It uses a priority cuts algorithm to maintain a
622/// bounded set of promising cuts while avoiding exponential explosion.
623///
624/// The enumeration process works by:
625/// 1. Visiting nodes in topological order
626/// 2. For each node, combining cuts from its inputs
627/// 3. Matching generated cuts against available patterns
628/// 4. Maintaining only the most promising cuts per node
630public:
631 /// Constructor for cut enumerator.
633
634 /// Enumerate cuts for all nodes in the given module.
635 ///
636 /// This is the main entry point that orchestrates the cut enumeration
637 /// process. It visits all operations in the module and generates cuts
638 /// for combinational logic operations.
639 LogicalResult enumerateCuts(
640 Operation *topOp,
641 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
642 [](const Cut &) { return std::nullopt; });
643
644 /// Create a new cut set for an index.
645 /// The index must not already have a cut set.
646 CutSet *createNewCutSet(uint32_t index);
647
648 /// Get the cut set for a specific index.
649 /// If not found, it means no cuts have been generated for this value yet.
650 /// In that case return a trivial cut set.
651 const CutSet *getCutSet(uint32_t index);
652
653 /// Clear all cut sets and reset the enumerator.
654 void clear();
655
656 /// Get the cut rewriter options used for this enumeration.
657 const CutRewriterOptions &getOptions() const { return options; }
658
659 /// Record that one cut was successfully rewritten.
661
662 void dump() const;
663
664 /// Get cut sets (indexed by LogicNetwork index).
665 const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
666 return cutSets;
667 }
668
669 /// Get the processing order.
670 ArrayRef<uint32_t> getProcessingOrder() const { return processingOrder; }
671
672private:
673 /// Visit a combinational logic operation and generate cuts.
674 /// This handles the core cut enumeration logic for operations
675 /// like AND, OR, XOR, etc.
676 LogicalResult visitLogicOp(uint32_t nodeIndex);
677
678 /// Maps indices to their associated cut sets.
679 /// CutSets are allocated from the bump allocator.
680 llvm::DenseMap<uint32_t, CutSet *> cutSets;
681
682 /// Typed bump allocators for fast allocation with destructors.
685
686 /// Indices in processing order.
687 llvm::SmallVector<uint32_t> processingOrder;
688
689 /// Configuration options for cut enumeration.
691
692 /// Function to match cuts against available patterns.
693 /// Set during enumeration and used when finalizing cut sets.
694 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
695
696 /// Flat logic network representation used during enumeration/rewrite.
698
699 /// Statistics for cut enumeration (number of cuts allocated, etc.).
701
702public:
703 /// Get the logic network (read-only).
704 const LogicNetwork &getLogicNetwork() const { return logicNetwork; }
705
706 /// Get the logic network (mutable).
708
709 /// Get enumeration statistics.
710 const CutEnumeratorStats &getStats() const { return stats; }
711};
712
713/// Base class for cut rewriting patterns used in combinational logic
714/// optimization.
715///
716/// A CutRewritePattern represents a library component or optimization pattern
717/// that can replace cuts in the combinational logic network. Each pattern
718/// defines:
719/// - How to recognize matching cuts and compute area/delay metrics
720/// - How to transform/replace the matched cuts
721///
722/// Patterns can use truth table matching for efficient recognition or
723/// implement custom matching logic for more complex cases.
725 CutRewritePattern(mlir::MLIRContext *context) : context(context) {}
726 /// Virtual destructor for base class.
727 virtual ~CutRewritePattern() = default;
728
729 /// Check if a cut matches this pattern and compute area/delay metrics.
730 ///
731 /// This method is called to determine if a cut can be replaced by this
732 /// pattern. If the cut matches, it should return a MatchResult containing
733 /// the area and per-input delays for this specific cut.
734 ///
735 /// If useTruthTableMatcher() returns true, this method is only
736 /// called for cuts with matching truth tables.
737 virtual std::optional<MatchResult> match(CutEnumerator &enumerator,
738 const Cut &cut) const = 0;
739
740 /// Specify truth tables that this pattern can match.
741 ///
742 /// If this method returns true, the pattern matcher will use truth table
743 /// comparison for efficient pre-filtering. Only cuts with matching truth
744 /// tables will be passed to the match() method. If it returns false, the
745 /// pattern will be checked against all cuts regardless of their truth tables.
746 /// This is useful for patterns that match regardless of their truth tables,
747 /// such as LUT-based patterns.
748 virtual bool
749 useTruthTableMatcher(SmallVectorImpl<NPNClass> &matchingNPNClasses) const;
750
751 /// Return a new operation that replaces the matched cut.
752 ///
753 /// Unlike MLIR's RewritePattern framework which allows arbitrary in-place
754 /// modifications, this method creates a new operation to replace the matched
755 /// cut rather than modifying existing operations. This constraint exists
756 /// because the cut enumerator maintains references to operations throughout
757 /// the circuit, making it safe to only replace the root operation of each
758 /// cut while preserving all other operations unchanged.
759 virtual FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
760 CutEnumerator &enumerator,
761 const Cut &cut) const = 0;
762
763 /// Get the number of outputs this pattern produces.
764 virtual unsigned getNumOutputs() const = 0;
765
766 /// Get the name of this pattern. Used for debugging.
767 virtual StringRef getPatternName() const { return "<unnamed>"; }
768
769 /// Get location for this pattern(optional).
770 virtual LocationAttr getLoc() const { return mlir::UnknownLoc::get(context); }
771
772 mlir::MLIRContext *getContext() const { return context; }
773
774private:
775 mlir::MLIRContext *context;
776};
777
778/// Manages a collection of rewriting patterns for combinational logic
779/// optimization.
780///
781/// This class organizes and provides efficient access to rewriting patterns
782/// used during cut-based optimization. It maintains:
783/// - A collection of all available patterns
784/// - Fast lookup tables for truth table-based matching
785/// - Separation of truth table vs. custom matching patterns
786///
787/// The pattern set is used by the CutRewriter to find suitable replacements
788/// for cuts in the combinational logic network.
790public:
791 /// Constructor that takes ownership of the provided patterns.
792 ///
793 /// During construction, patterns are analyzed and organized for efficient
794 /// lookup. Truth table matchers are indexed by their NPN canonical forms.
796 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns);
797
798private:
799 /// Owned collection of all rewriting patterns.
800 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns;
801
802 /// Fast lookup table mapping NPN canonical forms to matching patterns.
803 /// Each entry maps a truth table and input size to patterns that can handle
804 /// it.
805 DenseMap<std::pair<APInt, unsigned>,
806 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
808
809 /// Patterns that use custom matching logic instead of NPN lookup.
810 /// These patterns are checked against every cut.
811 SmallVector<const CutRewritePattern *, 4> nonNPNPatterns;
812
813 /// CutRewriter needs access to internal data structures for pattern matching.
814 friend class CutRewriter;
815};
816
817/// Main cut-based rewriting algorithm for combinational logic optimization.
818///
819/// The CutRewriter implements a cut-based rewriting algorithm that:
820/// 1. Enumerates cuts in the combinational logic network using a priority cuts
821/// algorithm
822/// 2. Matches cuts against available rewriting patterns
823/// 3. Selects optimal patterns based on area or timing objectives
824/// 4. Rewrites the circuit using the selected patterns
825///
826/// The algorithm processes the network in topological order, building up cut
827/// sets for each node and selecting the best implementation based on the
828/// specified optimization strategy.
829///
830/// Usage example:
831/// ```cpp
832/// CutRewriterOptions options;
833/// options.strategy = OptimizationStrategy::Area;
834/// options.maxCutInputSize = 4;
835/// options.maxCutSizePerRoot = 8;
836///
837/// CutRewritePatternSet patterns(std::move(optimizationPatterns));
838/// CutRewriter rewriter(module, options, patterns);
839///
840/// if (failed(rewriter.run())) {
841/// // Handle rewriting failure
842/// }
843/// ```
845public:
846 /// Constructor for the cut rewriter.
849
850 /// Execute the complete cut-based rewriting algorithm.
851 ///
852 /// This method orchestrates the entire rewriting process:
853 /// 1. Enumerate cuts for all nodes in the combinational logic
854 /// 2. Match cuts against available patterns
855 /// 3. Select optimal patterns based on strategy
856 /// 4. Rewrite the circuit with selected patterns
857 LogicalResult run(Operation *topOp);
858
860 return cutEnumerator.getStats();
861 }
862
863private:
864 /// Enumerate cuts for all nodes in the given module.
865 /// Note: This preserves module boundaries and does not perform
866 /// rewriting across the hierarchy.
867 LogicalResult enumerateCuts(Operation *topOp);
868
869 /// Find patterns that match a cut's truth table.
870 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
871 getMatchingPatternsFromTruthTable(const Cut &cut) const;
872
873 /// Match a cut against available patterns and compute arrival time.
874 std::optional<MatchedPattern> patternMatchCut(const Cut &cut);
875
876 /// Perform the actual circuit rewriting using selected patterns.
877 LogicalResult runBottomUpRewrite(Operation *topOp);
878
879 /// Configuration options
881
882 /// Available rewriting patterns
884
886};
887
888} // namespace synth
889} // namespace circt
890
891#endif // CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumeratorStats stats
Statistics for cut enumeration (number of cuts allocated, etc.).
void noteCutRewritten()
Record that one cut was successfully rewritten.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
const CutEnumeratorStats & getStats() const
Get enumeration statistics.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutEnumeratorStats & getStats() const
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
Cut(uint32_t rootIndex, ArrayRef< uint32_t > inputs, uint64_t signature, ArrayRef< const Cut * > operandCuts={}, std::optional< BinaryTruthTable > truthTable=std::nullopt)
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::SpecificBumpPtrAllocator< T > allocator
TrackedSpecificBumpPtrAllocator(uint64_t &allocationCount)
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when disabled, fail if such a root operation exists.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
llvm::PointerIntPair< Operation *, 3, Kind > opAndKind
Operation pointer and kind packed together.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
Definition CutRewriter.h:75
uint32_t getIndex() const
Get the node index (without the inversion bit).
Definition CutRewriter.h:84
Signal operator!() const
Create an inverted version of this edge.
Definition CutRewriter.h:95
bool operator<(const Signal &other) const
Definition CutRewriter.h:99
bool operator!=(const Signal &other) const
Definition CutRewriter.h:98
Signal flipInversion() const
Definition CutRewriter.h:92
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
Definition CutRewriter.h:90
bool isInverted() const
Check if this edge is inverted.
Definition CutRewriter.h:87
bool operator==(const Signal &other) const
Definition CutRewriter.h:97
Signal(uint32_t index, bool inverted)
Definition CutRewriter.h:79
Signal(uint32_t raw)
Definition CutRewriter.h:81