CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.h
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines a general cut-based rewriting framework for
10// combinational logic optimization. The framework uses NPN-equivalence matching
11// with area and delay metrics to rewrite cuts (subgraphs) in combinational
12// circuits with optimal patterns.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
18
20#include "circt/Support/LLVM.h"
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
31#include <memory>
32#include <optional>
33#include <utility>
34
35namespace circt {
36namespace synth {
/// Type used to represent delays in the circuit. It is the user's
/// responsibility to keep units consistent: every delay must be expressed in
/// the same unit (usually femtoseconds, but any unit works).
using DelayType = int64_t;
41
/// Maximum number of inputs supported for truth table generation.
/// A truth table has 2^n entries for n inputs, so this limit prevents
/// excessive memory usage.
/// Declared `inline` so that every translation unit including this header
/// refers to one and the same object.
inline constexpr unsigned maxTruthTableInputs = 16;
46
47// This is a helper function to sort operations topologically in a logic
48// network. This is necessary for cut rewriting to ensure that operations are
49// processed in the correct order, respecting dependencies.
50LogicalResult topologicallySortLogicNetwork(mlir::Operation *op);
51
52// Get the truth table for a specific operation within a block.
53// Block must be a SSACFG or topologically sorted.
54FailureOr<BinaryTruthTable> getTruthTable(ValueRange values, Block *block);
55
56//===----------------------------------------------------------------------===//
57// Cut Data Structures
58//===----------------------------------------------------------------------===//
59
60// Forward declarations
62class CutRewriter;
63class CutEnumerator;
66class LogicNetwork;
67
68//===----------------------------------------------------------------------===//
69// Logic Network Data Structures (Flat IR for efficient cut enumeration)
70//===----------------------------------------------------------------------===//
71
/// Edge representation in the logic network.
/// Similar to mockturtle's signal, this packs a node index and an inversion
/// flag into a single 32-bit word: the index lives in bits [31:1] and the LSB
/// is set when the edge is inverted.
struct Signal {
  uint32_t data = 0;

  /// Construct the non-inverted edge to node index 0.
  Signal() = default;

  /// Construct an edge to node `index`, inverted when `inverted` is true.
  /// Only 31 bits are available for the index; a larger value would silently
  /// lose its top bit in the shift, so it is rejected in asserts-enabled
  /// builds.
  Signal(uint32_t index, bool inverted)
      : data((index << 1) | (inverted ? 1u : 0u)) {
    assert((index >> 31) == 0 && "Signal index does not fit in 31 bits");
  }

  /// Construct from an already encoded word (index << 1 | inverted).
  explicit Signal(uint32_t raw) : data(raw) {}

  /// Get the node index (without the inversion bit).
  uint32_t getIndex() const { return data >> 1; }

  /// Check if this edge is inverted.
  bool isInverted() const { return (data & 1u) != 0; }

  /// Get the raw data (index << 1 | inverted).
  uint32_t getRaw() const { return data; }

  /// Return the same edge with the inversion bit toggled.
  Signal flipInversion() const { return Signal(data ^ 1u); }

  /// Create an inverted version of this edge (same as flipInversion()).
  Signal operator!() const { return flipInversion(); }

  /// Comparisons operate on the raw encoding, so two edges to the same node
  /// with different inversion compare unequal, and the non-inverted edge
  /// orders before the inverted one.
  bool operator==(const Signal &other) const { return data == other.data; }
  bool operator!=(const Signal &other) const { return !(*this == other); }
  bool operator<(const Signal &other) const { return data < other.data; }
};
101
102/// Represents a single gate/node in the flat logic network.
103/// This structure is designed to be cache-friendly and supports up to 3 inputs
104/// (sufficient for AND, XOR, MAJ gates). For nodes with fewer inputs, unused
105/// edges have index 0 (constant 0 node).
106///
107/// Special indices:
108/// - Index 0: Constant 0
109/// - Index 1: Constant 1
110///
111/// It uses 8 bytes for operation pointer + enum, 12 bytes for edges = 20
112/// bytes per gate.
/// Kind of logic gate. Values are numbered explicitly and stay below 8 so
/// that they fit the 3-bit integer field of `opAndKind`.
enum Kind : uint8_t {
  Constant = 0,     ///< Constant 0/1 node (index 0 = const0, index 1 = const1)
  PrimaryInput = 1, ///< Primary input to the network
  And2 = 2,         ///< AND gate (2-input, aig::AndInverterOp)
  Xor2 = 3,         ///< XOR gate (2-input)
  Maj3 = 4,         ///< Reserved 3-input gate kind
  Identity = 5,     ///< Identity gate (used for 1-input inverter)
  Choice = 6        ///< Choice node (synth.choice)
};
124
125 /// Operation pointer and kind packed together.
126 /// The kind is stored in the low bits of the pointer.
127 llvm::PointerIntPair<Operation *, 3, Kind> opAndKind;
128
129 /// Fanin edges (up to 3 inputs). For AND gates, only edges[0] and edges[1]
130 /// are used. For PrimaryInput/Constant and Choice, none are used. The
131 /// inversion bit is encoded in each edge.
133
135 LogicNetworkGate(Operation *op, Kind kind,
136 llvm::ArrayRef<Signal> operands = {})
137 : opAndKind(op, kind), edges{} {
138 assert(operands.size() <= 3 && "Too many operands for LogicNetworkGate");
139 for (size_t i = 0; i < operands.size(); ++i)
140 edges[i] = operands[i];
141 }
142
143 /// Get the kind of this gate.
144 Kind getKind() const { return opAndKind.getInt(); }
145
146 /// Get the operation pointer (nullptr for constants).
147 Operation *getOperation() const { return opAndKind.getPointer(); }
148
149 /// Get the number of fanin edges based on kind.
150 unsigned getNumFanins() const {
151 switch (getKind()) {
152 case Constant:
153 case PrimaryInput:
154 return 0;
155 case And2:
156 case Xor2:
157 return 2;
158 case Maj3:
159 return 3;
160 case Identity:
161 return 1;
162 case Choice:
163 return 0;
164 }
165 llvm_unreachable("Unknown gate kind");
166 }
167
168 /// Check if this is a logic gate that can be part of a cut.
169 bool isLogicGate() const {
170 Kind k = getKind();
171 return k == And2 || k == Xor2 || k == Maj3 || k == Identity || k == Choice;
172 }
173
174 /// Check if this should always be a cut input (PI or constant).
175 bool isAlwaysCutInput() const {
176 Kind k = getKind();
177 return k == PrimaryInput || k == Constant;
178 }
179};
180
181/// Flat logic network representation for efficient cut enumeration.
182///
183/// This class provides a mockturtle-style flat representation of the
184/// combinational logic network. Each value in the MLIR IR is assigned a unique
185/// index, and gates are stored in a contiguous vector for cache efficiency.
186///
187/// The network supports:
188/// - O(1) lookup of gate information by index
189/// - Compact representation with inversion encoded in edges
190/// - Efficient simulation and truth table computation
191///
192/// Special reserved indices:
193/// - Index 0: Constant 0
194/// - Index 1: Constant 1
196public:
197 /// Special constant indices.
198 static constexpr uint32_t kConstant0 = 0;
199 static constexpr uint32_t kConstant1 = 1;
200
201 ArrayRef<LogicNetworkGate> getGates() const { return gates; }
202
204 // Reserve index 0 for constant 0 and index 1 for constant 1
205 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
206 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
207 // indexToValue needs placeholders for constants
208 indexToValue.push_back(Value()); // const0
209 indexToValue.push_back(Value()); // const1
210 }
211
212 /// Get a LogicEdge representing constant 0.
213 static Signal getConstant0() { return Signal(kConstant0, false); }
214
215 /// Get a LogicEdge representing constant 1 (constant 0 inverted).
216 static Signal getConstant1() { return Signal(kConstant0, true); }
217
218 /// Get or create an index for a value.
219 /// If the value doesn't have an index yet, assigns one and returns the index.
220 uint32_t getOrCreateIndex(Value value);
221
222 /// Get the raw index for a value. Asserts if value is not found.
223 /// Note: This returns only the index, not a Signal with inversion info.
224 /// Use hasIndex() to check existence first, or use getOrCreateIndex().
225 uint32_t getIndex(Value value) const;
226
227 /// Check if a value has been indexed.
228 bool hasIndex(Value value) const;
229
230 /// Get the value for a given raw index. Asserts if index is out of bounds.
231 /// Returns null Value for constant indices (0 and 1).
232 Value getValue(uint32_t index) const;
233
234 /// Fill values for the given raw indices.
235 void getValues(ArrayRef<uint32_t> indices,
236 SmallVectorImpl<Value> &values) const;
237
238 /// Get a Signal for a value.
239 /// Asserts if value not found - use hasIndex() first if unsure.
240 Signal getSignal(Value value, bool inverted) const {
241 return Signal(getIndex(value), inverted);
242 }
243
244 /// Get or create a Signal for a value.
245 Signal getOrCreateSignal(Value value, bool inverted) {
246 return Signal(getOrCreateIndex(value), inverted);
247 }
248
249 /// Get the value for a given Signal (extracts index from Signal).
250 Value getValue(Signal signal) const { return getValue(signal.getIndex()); }
251
252 /// Get the gate at a given index.
253 const LogicNetworkGate &getGate(uint32_t index) const { return gates[index]; }
254
255 /// Get mutable reference to gate at index.
256 LogicNetworkGate &getGate(uint32_t index) { return gates[index]; }
257
258 /// Get the total number of nodes in the network.
259 size_t size() const { return gates.size(); }
260
261 /// Add a primary input to the network.
262 uint32_t addPrimaryInput(Value value);
263
264 /// Add a gate with explicit result value and operand signals.
265 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result,
266 llvm::ArrayRef<Signal> operands = {});
267
268 /// Add a gate using op->getResult(0) as the result value.
269 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind,
270 llvm::ArrayRef<Signal> operands = {}) {
271 return addGate(op, kind, op->getResult(0), operands);
272 }
273
274 /// Build the logic network from a region/block in topological order.
275 /// Returns failure if the IR is not in a valid form.
276 LogicalResult buildFromBlock(Block *block);
277
278 /// Clear the network and reset to initial state.
279 void clear();
280
281private:
282 /// Map from MLIR Value to network index.
283 llvm::DenseMap<Value, uint32_t> valueToIndex;
284
285 /// Map from network index to MLIR Value.
286 llvm::SmallVector<Value> indexToValue;
287
288 /// Vector of all gates in the network.
289 llvm::SmallVector<LogicNetworkGate> gates;
290};
291
292/// Result of matching a cut against a pattern.
293///
294/// This structure contains the area and per-input delay information
295/// computed during pattern matching.
296///
297/// The delays can be stored in two ways:
298/// 1. As a reference to static/cached data (e.g., tech library delays)
299/// - Use setDelayRef() for zero-cost reference (no allocation)
300/// 2. As owned dynamic data (e.g., computed SOP delays)
301/// - Use setOwnedDelays() to transfer ownership
302///
304 /// Area cost of implementing this cut with the pattern.
305 double area = 0.0;
306
307 /// Default constructor.
308 MatchResult() = default;
309
310 /// Constructor with area and delays (by reference).
311 MatchResult(double area, ArrayRef<DelayType> delays)
312 : area(area), borrowedDelays(delays) {}
313
314 /// Set delays by reference (zero-cost for static/cached delays).
315 /// The caller must ensure the referenced data remains valid.
316 void setDelayRef(ArrayRef<DelayType> delays) { borrowedDelays = delays; }
317
318 /// Set delays by transferring ownership (for dynamically computed delays).
319 /// This moves the data into internal storage.
320 void setOwnedDelays(SmallVector<DelayType, 6> delays) {
321 ownedDelays.emplace(std::move(delays));
322 borrowedDelays = {};
323 }
324
325 /// Get all delays as an ArrayRef.
326 ArrayRef<DelayType> getDelays() const {
327 return ownedDelays.has_value() ? ArrayRef<DelayType>(*ownedDelays)
329 }
330
331private:
332 /// Borrowed delays (used when ownedDelays is empty).
333 /// Points to external data provided via setDelayRef().
334 ArrayRef<DelayType> borrowedDelays;
335
336 /// Owned delays (used when present).
337 /// Only allocated when setOwnedDelays() is called. When empty (std::nullopt),
338 /// moving this MatchResult avoids constructing/moving the SmallVector,
339 /// achieving zero-cost abstraction for the common case (borrowed delays).
340 std::optional<SmallVector<DelayType, 6>> ownedDelays;
341};
342
343/// Represents a cut that has been successfully matched to a rewriting pattern.
344///
345/// This class encapsulates the result of matching a cut against a rewriting
346/// pattern during optimization. It stores the matched pattern, the
347/// cut that was matched, and timing information needed for optimization.
349private:
350 const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
351 SmallVector<DelayType, 1>
352 arrivalTimes; ///< Arrival times of outputs from this pattern
353 double area = 0.0; ///< Area cost of this pattern
354
355public:
356 /// Default constructor creates an invalid matched pattern.
357 MatchedPattern() = default;
358
359 /// Constructor for a valid matched pattern.
361 SmallVector<DelayType, 1> arrivalTimes, double area)
362 : pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
363
364 /// Get the arrival time of signals through this pattern.
365 DelayType getArrivalTime(unsigned outputIndex) const;
366 ArrayRef<DelayType> getArrivalTimes() const;
367
368 /// Get the library pattern that was matched.
369 const CutRewritePattern *getPattern() const;
370
371 /// Get the area cost of using this pattern.
372 double getArea() const;
373};
374
375/// Represents a cut in the combinational logic network.
376///
377/// A cut is a subset of nodes in the combinational logic that forms a complete
378/// subgraph with a single output. It represents a portion of the circuit that
379/// can potentially be replaced with a single library gate or pattern.
380///
381/// The cut contains:
382/// - Input values: The boundary between the cut and the rest of the circuit
383/// - Operations: The logic operations within the cut boundary
384/// - Root operation: The output-driving operation of the cut
385///
386/// Cuts are used in combinational logic optimization to identify regions that
387/// can be optimized and replaced with more efficient implementations.
388class Cut {
389 /// Cached truth table for this cut.
390 /// Computed lazily when first accessed to avoid unnecessary computation.
391 mutable std::optional<BinaryTruthTable> truthTable;
392
393 /// Cached NPN canonical form for this cut.
394 /// Computed lazily from the truth table when first accessed.
395 mutable std::optional<NPNClass> npnClass;
396
397 std::optional<MatchedPattern> matchedPattern;
398
399 /// Root index in LogicNetwork (0 indicates no root for a trivial cut).
400 /// The root node produces the output of the cut.
401 uint32_t rootIndex = 0;
402
403 /// Signature bitset for fast cut size estimation.
404 /// Bit i is set if value with index i is in the cut's inputs.
405 /// This enables O(1) estimation of merged cut size using popcount.
406 uint64_t signature = 0;
407
408 /// Operand cuts used to create this cut (for lazy TT computation).
409 /// Stored to enable fast incremental truth table computation after
410 /// duplicate removal. Using raw pointers is safe since cuts are allocated
411 /// via bump allocator and live for the duration of enumeration.
412 llvm::SmallVector<const Cut *, 3> operandCuts;
413
414public:
415 Cut() = default;
416 Cut(uint32_t rootIndex, ArrayRef<uint32_t> inputs, uint64_t signature,
417 ArrayRef<const Cut *> operandCuts = {},
418 std::optional<BinaryTruthTable> truthTable = std::nullopt)
421 operandCuts(operandCuts.begin(), operandCuts.end()),
422 inputs(inputs.begin(), inputs.end()) {}
423
424 /// Create a trivial cut for a value.
425 static Cut getTrivialCut(uint32_t index);
426
427 /// External inputs to this cut (cut boundary).
428 /// Stored as LogicNetwork indices for efficient operations.
429 llvm::SmallVector<uint32_t, 6> inputs;
430
431 /// Check if this cut represents a trivial cut.
432 /// A trivial cut has no root operation and exactly one input.
433 bool isTrivialCut() const;
434
435 /// Get the root index in the LogicNetwork.
436 uint32_t getRootIndex() const { return rootIndex; }
437
438 /// Set the root index of this cut.
439 void setRootIndex(uint32_t idx) { rootIndex = idx; }
440
441 /// Get the signature of this cut.
442 uint64_t getSignature() const { return signature; }
443
444 /// Set the signature of this cut.
445 void setSignature(uint64_t sig) { signature = sig; }
446
447 /// Check if this cut dominates another (i.e., this cut's inputs are a subset
448 /// of the other's inputs). Uses signature pre-filtering for speed.
449 /// Both cuts must have sorted inputs.
450 bool dominates(const Cut &other) const;
451
452 /// Check if this cut dominates a set of sorted inputs with the given
453 /// signature.
454 bool dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const;
455
456 void dump(llvm::raw_ostream &os, const LogicNetwork &network) const;
457
458 /// Get the number of inputs to this cut.
459 unsigned getInputSize() const;
460
461 /// Get the number of outputs from root operation.
462 unsigned getOutputSize(const LogicNetwork &network) const;
463
464 /// Get the truth table for this cut.
465 /// The truth table represents the boolean function computed by this cut.
466 const std::optional<BinaryTruthTable> &getTruthTable() const {
467 return truthTable;
468 }
469
470 /// Compute truth table using fast incremental method from operand cuts.
471 /// Trivial cuts are handled directly; non-trivial cuts require that
472 /// operand cuts have already been set via setOperandCuts.
473 void computeTruthTableFromOperands(const LogicNetwork &network);
474
475 /// Set the truth table directly (used for incremental computation).
476 void setTruthTable(BinaryTruthTable tt) { truthTable.emplace(std::move(tt)); }
477
478 /// Set operand cuts for lazy truth table computation.
479 void setOperandCuts(ArrayRef<const Cut *> cuts) {
480 operandCuts.assign(cuts.begin(), cuts.end());
481 }
482
483 /// Get operand cuts (for fast TT computation).
484 ArrayRef<const Cut *> getOperandCuts() const { return operandCuts; }
485
486 /// Get the NPN canonical form for this cut.
487 /// This is used for efficient pattern matching against library components.
488 const NPNClass &getNPNClass() const;
489 const NPNClass &getNPNClass(const NPNTable *npnTable) const;
490
491 /// Get the permutated inputs for this cut based on the given pattern NPN.
492 /// Returns indices into the inputs vector.
493 void
494 getPermutatedInputIndices(const NPNTable *npnTable,
495 const NPNClass &patternNPN,
496 SmallVectorImpl<unsigned> &permutedIndices) const;
497
498 /// Get arrival times for each input of this cut.
499 /// Returns failure if any input doesn't have a valid matched pattern.
500 LogicalResult getInputArrivalTimes(CutEnumerator &enumerator,
501 SmallVectorImpl<DelayType> &results) const;
502
503 /// Matched pattern for this cut.
507
508 /// Get the matched pattern for this cut.
509 const std::optional<MatchedPattern> &getMatchedPattern() const {
510 return matchedPattern;
511 }
512};
513
514/// Manages a collection of cuts for a single logic node using priority cuts
515/// algorithm.
516///
517/// Each node in the combinational logic network can have multiple cuts
518/// representing different ways to group it with surrounding logic. The CutSet
519/// manages these cuts and selects the best one based on the optimization
520/// strategy (area or timing).
521///
522/// The priority cuts algorithm maintains a bounded set of the most promising
523/// cuts to avoid exponential explosion while ensuring good optimization
524/// results.
525class CutSet {
526private:
527 llvm::SmallVector<Cut *, 12> cuts; ///< Collection of cuts for this node
528 Cut *bestCut = nullptr;
529 bool isFrozen = false; ///< Whether cut set is finalized
530
531public:
532 /// Check if this cut set has a valid matched pattern.
533 bool isMatched() const { return bestCut; }
534
535 /// Get the cut associated with the best matched pattern.
536 Cut *getBestMatchedCut() const;
537
538 /// Finalize the cut set by removing duplicates and selecting the best
539 /// pattern.
540 void finalize(
541 const CutRewriterOptions &options,
542 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
543 const LogicNetwork &logicNetwork);
544
545 /// Get the number of cuts in this set.
546 unsigned size() const;
547
548 /// Add a new cut to this set using bump allocator.
549 /// NOTE: The cut set must not be frozen
550 void addCut(Cut *cut);
551
552 /// Get read-only access to all cuts in this set.
553 ArrayRef<Cut *> getCuts() const;
554};
555
556/// Configuration options for the cut-based rewriting algorithm.
557///
558/// These options control various aspects of the rewriting process including
559/// optimization strategy, resource limits, and algorithmic parameters.
561 /// Optimization strategy (area vs. timing).
563
564 /// Maximum number of inputs allowed for any cut.
565 /// Larger cuts provide more optimization opportunities but increase
566 /// computational complexity exponentially.
568
569 /// Maximum number of cuts to maintain per logic node.
570 /// The priority cuts algorithm keeps only the most promising cuts
571 /// to prevent exponential explosion.
573
/// Whether a root operation that has no matching pattern is tolerated.
/// With the default (false), rewriting fails if there is a root operation
/// that has no matching pattern.
bool allowNoMatch = false;
576
577 /// Put arrival times to rewritten operations.
578 bool attachDebugTiming = false;
579
580 /// Run priority cuts enumeration and dump the cut sets.
581 bool testPriorityCuts = false;
582
583 /// Optional lookup table used to accelerate 4-input NPN canonicalization.
584 const NPNTable *npnTable = nullptr;
585};
586
587//===----------------------------------------------------------------------===//
588// Cut Enumeration Engine
589//===----------------------------------------------------------------------===//
590
592 uint64_t numCutsCreated = 0;
593 uint64_t numCutSetsCreated = 0;
594 uint64_t numCutsRewritten = 0;
595};
596
597template <typename T>
599public:
602
603 template <typename... Args>
604 T *create(Args &&...args) {
606 return new (allocator.Allocate()) T(std::forward<Args>(args)...);
607 }
608
609 void DestroyAll() { allocator.DestroyAll(); }
610
611private:
612 llvm::SpecificBumpPtrAllocator<T> allocator;
614};
615/// Cut enumeration engine for combinational logic networks.
616///
617/// The CutEnumerator is responsible for generating cuts for each node in a
618/// combinational logic network. It uses a priority cuts algorithm to maintain a
619/// bounded set of promising cuts while avoiding exponential explosion.
620///
621/// The enumeration process works by:
622/// 1. Visiting nodes in topological order
623/// 2. For each node, combining cuts from its inputs
624/// 3. Matching generated cuts against available patterns
625/// 4. Maintaining only the most promising cuts per node
627public:
628 /// Constructor for cut enumerator.
630
631 /// Enumerate cuts for all nodes in the given module.
632 ///
633 /// This is the main entry point that orchestrates the cut enumeration
634 /// process. It visits all operations in the module and generates cuts
635 /// for combinational logic operations.
636 LogicalResult enumerateCuts(
637 Operation *topOp,
638 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
639 [](const Cut &) { return std::nullopt; });
640
641 /// Create a new cut set for an index.
642 /// The index must not already have a cut set.
643 CutSet *createNewCutSet(uint32_t index);
644
645 /// Get the cut set for a specific index.
646 /// If not found, it means no cuts have been generated for this value yet.
647 /// In that case return a trivial cut set.
648 const CutSet *getCutSet(uint32_t index);
649
650 /// Clear all cut sets and reset the enumerator.
651 void clear();
652
653 /// Get the cut rewriter options used for this enumeration.
654 const CutRewriterOptions &getOptions() const { return options; }
655
656 /// Record that one cut was successfully rewritten.
658
659 void dump() const;
660
661 /// Get cut sets (indexed by LogicNetwork index).
662 const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
663 return cutSets;
664 }
665
666 /// Get the processing order.
667 ArrayRef<uint32_t> getProcessingOrder() const { return processingOrder; }
668
669private:
670 /// Visit a combinational logic operation and generate cuts.
671 /// This handles the core cut enumeration logic for operations
672 /// like AND, OR, XOR, etc.
673 LogicalResult visitLogicOp(uint32_t nodeIndex);
674
675 /// Maps indices to their associated cut sets.
676 /// CutSets are allocated from the bump allocator.
677 llvm::DenseMap<uint32_t, CutSet *> cutSets;
678
679 /// Typed bump allocators for fast allocation with destructors.
682
683 /// Indices in processing order.
684 llvm::SmallVector<uint32_t> processingOrder;
685
686 /// Configuration options for cut enumeration.
688
689 /// Function to match cuts against available patterns.
690 /// Set during enumeration and used when finalizing cut sets.
691 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
692
693 /// Flat logic network representation used during enumeration/rewrite.
695
696 /// Statistics for cut enumeration (number of cuts allocated, etc.).
698
699public:
700 /// Get the logic network (read-only).
701 const LogicNetwork &getLogicNetwork() const { return logicNetwork; }
702
703 /// Get the logic network (mutable).
705
706 /// Get enumeration statistics.
707 const CutEnumeratorStats &getStats() const { return stats; }
708};
709
710/// Base class for cut rewriting patterns used in combinational logic
711/// optimization.
712///
713/// A CutRewritePattern represents a library component or optimization pattern
714/// that can replace cuts in the combinational logic network. Each pattern
715/// defines:
716/// - How to recognize matching cuts and compute area/delay metrics
717/// - How to transform/replace the matched cuts
718///
719/// Patterns can use truth table matching for efficient recognition or
720/// implement custom matching logic for more complex cases.
722 CutRewritePattern(mlir::MLIRContext *context) : context(context) {}
723 /// Virtual destructor for base class.
724 virtual ~CutRewritePattern() = default;
725
726 /// Check if a cut matches this pattern and compute area/delay metrics.
727 ///
728 /// This method is called to determine if a cut can be replaced by this
729 /// pattern. If the cut matches, it should return a MatchResult containing
730 /// the area and per-input delays for this specific cut.
731 ///
732 /// If useTruthTableMatcher() returns true, this method is only
733 /// called for cuts with matching truth tables.
734 virtual std::optional<MatchResult> match(CutEnumerator &enumerator,
735 const Cut &cut) const = 0;
736
737 /// Specify truth tables that this pattern can match.
738 ///
739 /// If this method returns true, the pattern matcher will use truth table
740 /// comparison for efficient pre-filtering. Only cuts with matching truth
741 /// tables will be passed to the match() method. If it returns false, the
742 /// pattern will be checked against all cuts regardless of their truth tables.
743 /// This is useful for patterns that match regardless of their truth tables,
744 /// such as LUT-based patterns.
745 virtual bool
746 useTruthTableMatcher(SmallVectorImpl<NPNClass> &matchingNPNClasses) const;
747
748 /// Return a new operation that replaces the matched cut.
749 ///
750 /// Unlike MLIR's RewritePattern framework which allows arbitrary in-place
751 /// modifications, this method creates a new operation to replace the matched
752 /// cut rather than modifying existing operations. This constraint exists
753 /// because the cut enumerator maintains references to operations throughout
754 /// the circuit, making it safe to only replace the root operation of each
755 /// cut while preserving all other operations unchanged.
756 virtual FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
757 CutEnumerator &enumerator,
758 const Cut &cut) const = 0;
759
760 /// Get the number of outputs this pattern produces.
761 virtual unsigned getNumOutputs() const = 0;
762
763 /// Get the name of this pattern. Used for debugging.
764 virtual StringRef getPatternName() const { return "<unnamed>"; }
765
766 /// Get location for this pattern(optional).
767 virtual LocationAttr getLoc() const { return mlir::UnknownLoc::get(context); }
768
769 mlir::MLIRContext *getContext() const { return context; }
770
771private:
772 mlir::MLIRContext *context;
773};
774
775/// Manages a collection of rewriting patterns for combinational logic
776/// optimization.
777///
778/// This class organizes and provides efficient access to rewriting patterns
779/// used during cut-based optimization. It maintains:
780/// - A collection of all available patterns
781/// - Fast lookup tables for truth table-based matching
782/// - Separation of truth table vs. custom matching patterns
783///
784/// The pattern set is used by the CutRewriter to find suitable replacements
785/// for cuts in the combinational logic network.
787public:
788 /// Constructor that takes ownership of the provided patterns.
789 ///
790 /// During construction, patterns are analyzed and organized for efficient
791 /// lookup. Truth table matchers are indexed by their NPN canonical forms.
793 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns);
794
795private:
796 /// Owned collection of all rewriting patterns.
797 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns;
798
799 /// Fast lookup table mapping NPN canonical forms to matching patterns.
800 /// Each entry maps a truth table and input size to patterns that can handle
801 /// it.
802 DenseMap<std::pair<APInt, unsigned>,
803 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
805
806 /// Patterns that use custom matching logic instead of NPN lookup.
807 /// These patterns are checked against every cut.
808 SmallVector<const CutRewritePattern *, 4> nonNPNPatterns;
809
810 /// CutRewriter needs access to internal data structures for pattern matching.
811 friend class CutRewriter;
812};
813
814/// Main cut-based rewriting algorithm for combinational logic optimization.
815///
816/// The CutRewriter implements a cut-based rewriting algorithm that:
817/// 1. Enumerates cuts in the combinational logic network using a priority cuts
818/// algorithm
819/// 2. Matches cuts against available rewriting patterns
820/// 3. Selects optimal patterns based on area or timing objectives
821/// 4. Rewrites the circuit using the selected patterns
822///
823/// The algorithm processes the network in topological order, building up cut
824/// sets for each node and selecting the best implementation based on the
825/// specified optimization strategy.
826///
827/// Usage example:
828/// ```cpp
829/// CutRewriterOptions options;
830/// options.strategy = OptimizationStrategy::Area;
831/// options.maxCutInputSize = 4;
832/// options.maxCutSizePerRoot = 8;
833///
834/// CutRewritePatternSet patterns(std::move(optimizationPatterns));
835/// CutRewriter rewriter(module, options, patterns);
836///
837/// if (failed(rewriter.run())) {
838/// // Handle rewriting failure
839/// }
840/// ```
842public:
843 /// Constructor for the cut rewriter.
846
847 /// Execute the complete cut-based rewriting algorithm.
848 ///
849 /// This method orchestrates the entire rewriting process:
850 /// 1. Enumerate cuts for all nodes in the combinational logic
851 /// 2. Match cuts against available patterns
852 /// 3. Select optimal patterns based on strategy
853 /// 4. Rewrite the circuit with selected patterns
854 LogicalResult run(Operation *topOp);
855
857 return cutEnumerator.getStats();
858 }
859
860private:
861 /// Enumerate cuts for all nodes in the given module.
862 /// Note: This preserves module boundaries and does not perform
863 /// rewriting across the hierarchy.
864 LogicalResult enumerateCuts(Operation *topOp);
865
866 /// Find patterns that match a cut's truth table.
867 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
868 getMatchingPatternsFromTruthTable(const Cut &cut) const;
869
870 /// Match a cut against available patterns and compute arrival time.
871 std::optional<MatchedPattern> patternMatchCut(const Cut &cut);
872
873 /// Perform the actual circuit rewriting using selected patterns.
874 LogicalResult runBottomUpRewrite(Operation *topOp);
875
876 /// Configuration options
878
879 /// Available rewriting patterns
881
883};
884
885} // namespace synth
886} // namespace circt
887
888#endif // CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumeratorStats stats
Statistics for cut enumeration (number of cuts allocated, etc.).
void noteCutRewritten()
Record that one cut was successfully rewritten.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
const CutEnumeratorStats & getStats() const
Get enumeration statistics.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
const CutRewriterOptions & getOptions() const
Get the cut rewriter options used for this enumeration.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutEnumeratorStats & getStats() const
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
Cut(uint32_t rootIndex, ArrayRef< uint32_t > inputs, uint64_t signature, ArrayRef< const Cut * > operandCuts={}, std::optional< BinaryTruthTable > truthTable=std::nullopt)
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void setMatchedPattern(MatchedPattern pattern)
Matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::SpecificBumpPtrAllocator< T > allocator
TrackedSpecificBumpPtrAllocator(uint64_t &allocationCount)
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when not allowed, such a root operation causes a failure.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
llvm::PointerIntPair< Operation *, 3, Kind > opAndKind
Operation pointer and kind packed together.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
Definition CutRewriter.h:75
uint32_t getIndex() const
Get the node index (without the inversion bit).
Definition CutRewriter.h:84
Signal operator!() const
Create an inverted version of this edge.
Definition CutRewriter.h:95
bool operator<(const Signal &other) const
Definition CutRewriter.h:99
bool operator!=(const Signal &other) const
Definition CutRewriter.h:98
Signal flipInversion() const
Definition CutRewriter.h:92
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
Definition CutRewriter.h:90
bool isInverted() const
Check if this edge is inverted.
Definition CutRewriter.h:87
bool operator==(const Signal &other) const
Definition CutRewriter.h:97
Signal(uint32_t index, bool inverted)
Definition CutRewriter.h:79
Signal(uint32_t raw)
Definition CutRewriter.h:81