CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.h
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines a general cut-based rewriting framework for
10// combinational logic optimization. The framework uses NPN-equivalence matching
11// with area and delay metrics to rewrite cuts (subgraphs) in combinational
12// circuits with optimal patterns.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
18
20#include "circt/Support/LLVM.h"
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
31#include <memory>
32#include <optional>
33#include <utility>
34
35namespace circt {
36namespace synth {
// Type for representing delays in the circuit. It is the user's responsibility
// to use consistent units, i.e., all delays should be in the same unit
// (usually femtoseconds, but not limited to that).
40using DelayType = int64_t;
41
42/// Maximum number of inputs supported for truth table generation.
43/// This limit prevents excessive memory usage as truth table size grows
44/// exponentially with the number of inputs (2^n entries).
45static constexpr unsigned maxTruthTableInputs = 16;
46
47// This is a helper function to sort operations topologically in a logic
48// network. This is necessary for cut rewriting to ensure that operations are
49// processed in the correct order, respecting dependencies.
50LogicalResult topologicallySortLogicNetwork(mlir::Operation *op);
51
// Get the truth table for the given values within a block.
// The block must be in an SSACFG region or be topologically sorted.
54FailureOr<BinaryTruthTable> getTruthTable(ValueRange values, Block *block);
55
56//===----------------------------------------------------------------------===//
57// Cut Data Structures
58//===----------------------------------------------------------------------===//
59
60// Forward declarations
62class CutRewriter;
63class CutEnumerator;
66class LogicNetwork;
67
68//===----------------------------------------------------------------------===//
69// Logic Network Data Structures (Flat IR for efficient cut enumeration)
70//===----------------------------------------------------------------------===//
71
/// Edge representation in the logic network.
/// Similar to mockturtle's signal, this encodes both a node index and an
/// inversion flag in a single 32-bit value: bit 0 is the inversion flag and
/// the remaining 31 bits hold the node index.
struct Signal {
  uint32_t data = 0;

  /// Default signal: node index 0, not inverted.
  Signal() = default;

  /// Build a signal from a node index and an inversion flag.
  /// Only 31 bits are available for the index; a larger value would lose its
  /// top bit in the shift below and silently alias another node, so it is
  /// rejected in assert-enabled builds.
  Signal(uint32_t index, bool inverted)
      : data((index << 1) | (inverted ? 1u : 0u)) {
    assert((index >> 31) == 0 && "Signal index must fit in 31 bits");
  }

  /// Build a signal directly from its raw encoding (index << 1 | inverted).
  explicit Signal(uint32_t raw) : data(raw) {}

  /// Get the node index (without the inversion bit).
  uint32_t getIndex() const { return data >> 1; }

  /// Check if this edge is inverted.
  bool isInverted() const { return (data & 1u) != 0; }

  /// Get the raw data (index << 1 | inverted).
  uint32_t getRaw() const { return data; }

  /// Return the same edge with its inversion bit toggled; equivalent to
  /// operator!.
  Signal flipInversion() const { return Signal(data ^ 1u); }

  /// Create an inverted version of this edge.
  Signal operator!() const { return Signal(data ^ 1u); }

  /// Comparisons operate on the raw encoding, so for the same node the
  /// non-inverted edge orders before the inverted one.
  bool operator==(const Signal &other) const { return data == other.data; }
  bool operator!=(const Signal &other) const { return data != other.data; }
  bool operator<(const Signal &other) const { return data < other.data; }
};
101
102/// Represents a single gate/node in the flat logic network.
103/// This structure is designed to be cache-friendly and supports up to 3 inputs
104/// (sufficient for AND, XOR, MAJ gates). For nodes with fewer inputs, unused
105/// edges have index 0 (constant 0 node).
106///
107/// Special indices:
108/// - Index 0: Constant 0
109/// - Index 1: Constant 1
110///
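/// On 64-bit targets it uses 8 bytes for the packed operation pointer + kind
/// and 12 bytes for the three edges, i.e. 20 bytes of payload per gate
/// (presumably rounded up to 24 by pointer alignment; confirm with sizeof).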
114 /// Kind of logic gate.
115 enum Kind : uint8_t {
116 Constant = 0, ///< Constant 0/1 node (index 0 = const0, index 1 = const1)
117 PrimaryInput = 1, ///< Primary input to the network
118 And2 = 2, ///< AND gate (2-input, aig::AndInverterOp)
119 Xor2 = 3, ///< XOR gate (2-input)
120 Maj3 = 4, ///< Majority gate (3-input, mig::MajOp)
121 Identity = 5 ///< Identity gate (used for 1-input inverter)
122 };
123
124 /// Operation pointer and kind packed together.
125 /// The kind is stored in the low bits of the pointer.
126 llvm::PointerIntPair<Operation *, 3, Kind> opAndKind;
127
128 /// Fanin edges (up to 3 inputs). For AND gates, only edges[0] and edges[1]
129 /// are used. For MAJ gates, all three are used. For PrimaryInput/Constant,
130 /// none are used. The inversion bit is encoded in each edge.
132
134 LogicNetworkGate(Operation *op, Kind kind,
135 llvm::ArrayRef<Signal> operands = {})
136 : opAndKind(op, kind), edges{} {
137 assert(operands.size() <= 3 && "Too many operands for LogicNetworkGate");
138 for (size_t i = 0; i < operands.size(); ++i)
139 edges[i] = operands[i];
140 }
141
142 /// Get the kind of this gate.
143 Kind getKind() const { return opAndKind.getInt(); }
144
145 /// Get the operation pointer (nullptr for constants).
146 Operation *getOperation() const { return opAndKind.getPointer(); }
147
148 /// Get the number of fanin edges based on kind.
149 unsigned getNumFanins() const {
150 switch (getKind()) {
151 case Constant:
152 case PrimaryInput:
153 return 0;
154 case And2:
155 case Xor2:
156 return 2;
157 case Maj3:
158 return 3;
159 case Identity:
160 return 1;
161 }
162 llvm_unreachable("Unknown gate kind");
163 }
164
165 /// Check if this is a logic gate that can be part of a cut.
166 bool isLogicGate() const {
167 Kind k = getKind();
168 return k == And2 || k == Xor2 || k == Maj3 || k == Identity;
169 }
170
171 /// Check if this should always be a cut input (PI or constant).
172 bool isAlwaysCutInput() const {
173 Kind k = getKind();
174 return k == PrimaryInput || k == Constant;
175 }
176};
177
178/// Flat logic network representation for efficient cut enumeration.
179///
180/// This class provides a mockturtle-style flat representation of the
181/// combinational logic network. Each value in the MLIR IR is assigned a unique
182/// index, and gates are stored in a contiguous vector for cache efficiency.
183///
184/// The network supports:
185/// - O(1) lookup of gate information by index
186/// - Compact representation with inversion encoded in edges
187/// - Efficient simulation and truth table computation
188///
189/// Special reserved indices:
190/// - Index 0: Constant 0
191/// - Index 1: Constant 1
193public:
194 /// Special constant indices.
195 static constexpr uint32_t kConstant0 = 0;
196 static constexpr uint32_t kConstant1 = 1;
197
198 ArrayRef<LogicNetworkGate> getGates() const { return gates; }
199
201 // Reserve index 0 for constant 0 and index 1 for constant 1
202 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
203 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
204 // indexToValue needs placeholders for constants
205 indexToValue.push_back(Value()); // const0
206 indexToValue.push_back(Value()); // const1
207 }
208
/// Get a Signal representing constant 0.
210 static Signal getConstant0() { return Signal(kConstant0, false); }
211
/// Get a Signal representing constant 1 (constant 0 inverted).
213 static Signal getConstant1() { return Signal(kConstant0, true); }
214
215 /// Get or create an index for a value.
216 /// If the value doesn't have an index yet, assigns one and returns the index.
217 uint32_t getOrCreateIndex(Value value);
218
219 /// Get the raw index for a value. Asserts if value is not found.
220 /// Note: This returns only the index, not a Signal with inversion info.
221 /// Use hasIndex() to check existence first, or use getOrCreateIndex().
222 uint32_t getIndex(Value value) const;
223
224 /// Check if a value has been indexed.
225 bool hasIndex(Value value) const;
226
227 /// Get the value for a given raw index. Asserts if index is out of bounds.
228 /// Returns null Value for constant indices (0 and 1).
229 Value getValue(uint32_t index) const;
230
231 /// Fill values for the given raw indices.
232 void getValues(ArrayRef<uint32_t> indices,
233 SmallVectorImpl<Value> &values) const;
234
235 /// Get a Signal for a value.
236 /// Asserts if value not found - use hasIndex() first if unsure.
237 Signal getSignal(Value value, bool inverted) const {
238 return Signal(getIndex(value), inverted);
239 }
240
241 /// Get or create a Signal for a value.
242 Signal getOrCreateSignal(Value value, bool inverted) {
243 return Signal(getOrCreateIndex(value), inverted);
244 }
245
246 /// Get the value for a given Signal (extracts index from Signal).
247 Value getValue(Signal signal) const { return getValue(signal.getIndex()); }
248
249 /// Get the gate at a given index.
250 const LogicNetworkGate &getGate(uint32_t index) const { return gates[index]; }
251
252 /// Get mutable reference to gate at index.
253 LogicNetworkGate &getGate(uint32_t index) { return gates[index]; }
254
255 /// Get the total number of nodes in the network.
256 size_t size() const { return gates.size(); }
257
258 /// Add a primary input to the network.
259 uint32_t addPrimaryInput(Value value);
260
261 /// Add a gate with explicit result value and operand signals.
262 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result,
263 llvm::ArrayRef<Signal> operands = {});
264
265 /// Add a gate using op->getResult(0) as the result value.
266 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind,
267 llvm::ArrayRef<Signal> operands = {}) {
268 return addGate(op, kind, op->getResult(0), operands);
269 }
270
271 /// Build the logic network from a region/block in topological order.
272 /// Returns failure if the IR is not in a valid form.
273 LogicalResult buildFromBlock(Block *block);
274
275 /// Clear the network and reset to initial state.
276 void clear();
277
278private:
279 /// Map from MLIR Value to network index.
280 llvm::DenseMap<Value, uint32_t> valueToIndex;
281
282 /// Map from network index to MLIR Value.
283 llvm::SmallVector<Value> indexToValue;
284
285 /// Vector of all gates in the network.
286 llvm::SmallVector<LogicNetworkGate> gates;
287};
288
289/// Result of matching a cut against a pattern.
290///
291/// This structure contains the area and per-input delay information
292/// computed during pattern matching.
293///
294/// The delays can be stored in two ways:
295/// 1. As a reference to static/cached data (e.g., tech library delays)
296/// - Use setDelayRef() for zero-cost reference (no allocation)
297/// 2. As owned dynamic data (e.g., computed SOP delays)
298/// - Use setOwnedDelays() to transfer ownership
299///
301 /// Area cost of implementing this cut with the pattern.
302 double area = 0.0;
303
304 /// Default constructor.
305 MatchResult() = default;
306
307 /// Constructor with area and delays (by reference).
308 MatchResult(double area, ArrayRef<DelayType> delays)
309 : area(area), borrowedDelays(delays) {}
310
311 /// Set delays by reference (zero-cost for static/cached delays).
312 /// The caller must ensure the referenced data remains valid.
313 void setDelayRef(ArrayRef<DelayType> delays) { borrowedDelays = delays; }
314
315 /// Set delays by transferring ownership (for dynamically computed delays).
316 /// This moves the data into internal storage.
317 void setOwnedDelays(SmallVector<DelayType, 6> delays) {
318 ownedDelays.emplace(std::move(delays));
319 borrowedDelays = {};
320 }
321
322 /// Get all delays as an ArrayRef.
323 ArrayRef<DelayType> getDelays() const {
324 return ownedDelays.has_value() ? ArrayRef<DelayType>(*ownedDelays)
326 }
327
328private:
329 /// Borrowed delays (used when ownedDelays is empty).
330 /// Points to external data provided via setDelayRef().
331 ArrayRef<DelayType> borrowedDelays;
332
333 /// Owned delays (used when present).
334 /// Only allocated when setOwnedDelays() is called. When empty (std::nullopt),
335 /// moving this MatchResult avoids constructing/moving the SmallVector,
336 /// achieving zero-cost abstraction for the common case (borrowed delays).
337 std::optional<SmallVector<DelayType, 6>> ownedDelays;
338};
339
340/// Represents a cut that has been successfully matched to a rewriting pattern.
341///
342/// This class encapsulates the result of matching a cut against a rewriting
343/// pattern during optimization. It stores the matched pattern, the
344/// cut that was matched, and timing information needed for optimization.
346private:
347 const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
348 SmallVector<DelayType, 1>
349 arrivalTimes; ///< Arrival times of outputs from this pattern
350 double area = 0.0; ///< Area cost of this pattern
351
352public:
353 /// Default constructor creates an invalid matched pattern.
354 MatchedPattern() = default;
355
356 /// Constructor for a valid matched pattern.
358 SmallVector<DelayType, 1> arrivalTimes, double area)
359 : pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
360
361 /// Get the arrival time of signals through this pattern.
362 DelayType getArrivalTime(unsigned outputIndex) const;
363 ArrayRef<DelayType> getArrivalTimes() const;
364
365 /// Get the library pattern that was matched.
366 const CutRewritePattern *getPattern() const;
367
368 /// Get the area cost of using this pattern.
369 double getArea() const;
370};
371
372/// Represents a cut in the combinational logic network.
373///
374/// A cut is a subset of nodes in the combinational logic that forms a complete
375/// subgraph with a single output. It represents a portion of the circuit that
376/// can potentially be replaced with a single library gate or pattern.
377///
378/// The cut contains:
379/// - Input values: The boundary between the cut and the rest of the circuit
380/// - Operations: The logic operations within the cut boundary
381/// - Root operation: The output-driving operation of the cut
382///
383/// Cuts are used in combinational logic optimization to identify regions that
384/// can be optimized and replaced with more efficient implementations.
385class Cut {
386 /// Cached truth table for this cut.
387 /// Computed lazily when first accessed to avoid unnecessary computation.
388 mutable std::optional<BinaryTruthTable> truthTable;
389
390 /// Cached NPN canonical form for this cut.
391 /// Computed lazily from the truth table when first accessed.
392 mutable std::optional<NPNClass> npnClass;
393
394 std::optional<MatchedPattern> matchedPattern;
395
396 /// Root index in LogicNetwork (0 indicates no root for a trivial cut).
397 /// The root node produces the output of the cut.
398 uint32_t rootIndex = 0;
399
400 /// Operand cuts used to create this cut (for lazy TT computation).
401 /// Stored to enable fast incremental truth table computation after
402 /// duplicate removal. Using raw pointers is safe since cuts are allocated
403 /// via bump allocator and live for the duration of enumeration.
404 llvm::SmallVector<const Cut *, 3> operandCuts;
405
406public:
407 /// External inputs to this cut (cut boundary).
408 /// Stored as LogicNetwork indices for efficient operations.
409 llvm::SmallVector<uint32_t, 6> inputs;
410
411 /// Check if this cut represents a trivial cut.
412 /// A trivial cut has no root operation and exactly one input.
413 bool isTrivialCut() const;
414
415 /// Get the root index in the LogicNetwork.
416 uint32_t getRootIndex() const { return rootIndex; }
417
418 /// Set the root index of this cut.
419 void setRootIndex(uint32_t idx) { rootIndex = idx; }
420
421 /// Check if this cut dominates another cut.
422 bool dominates(const Cut &other) const;
423
424 /// Check if this cut dominates another sorted input set.
425 bool dominates(ArrayRef<uint32_t> otherInputs) const;
426
427 void dump(llvm::raw_ostream &os, const LogicNetwork &network) const;
428
429 /// Get the number of inputs to this cut.
430 unsigned getInputSize() const;
431
432 /// Get the number of outputs from root operation.
433 unsigned getOutputSize(const LogicNetwork &network) const;
434
435 /// Get the truth table for this cut.
436 /// The truth table represents the boolean function computed by this cut.
437 const std::optional<BinaryTruthTable> &getTruthTable() const {
438 return truthTable;
439 }
440
441 /// Compute and cache the truth table for this cut using the LogicNetwork.
442 void computeTruthTable(const LogicNetwork &network);
443
444 /// Compute truth table using fast incremental method from operand cuts.
445 /// This is much faster than simulation-based computation.
446 /// Requires that operand cuts have already been set via setOperandCuts.
447 void computeTruthTableFromOperands(const LogicNetwork &network);
448
449 /// Set the truth table directly (used for incremental computation).
450 void setTruthTable(BinaryTruthTable tt) { truthTable.emplace(std::move(tt)); }
451
452 /// Set operand cuts for lazy truth table computation.
453 void setOperandCuts(ArrayRef<const Cut *> cuts) {
454 operandCuts.assign(cuts.begin(), cuts.end());
455 }
456
457 /// Get operand cuts (for fast TT computation).
458 ArrayRef<const Cut *> getOperandCuts() const { return operandCuts; }
459
460 /// Get the NPN canonical form for this cut.
461 /// This is used for efficient pattern matching against library components.
462 const NPNClass &getNPNClass() const;
463
/// Get the permuted input indices for this cut based on the given pattern
/// NPN. The results are indices into the inputs vector.
466 void
467 getPermutatedInputIndices(const NPNClass &patternNPN,
468 SmallVectorImpl<unsigned> &permutedIndices) const;
469
470 /// Get arrival times for each input of this cut.
471 /// Returns failure if any input doesn't have a valid matched pattern.
472 LogicalResult getInputArrivalTimes(CutEnumerator &enumerator,
473 SmallVectorImpl<DelayType> &results) const;
474
475 /// Matched pattern for this cut.
479
480 /// Get the matched pattern for this cut.
481 const std::optional<MatchedPattern> &getMatchedPattern() const {
482 return matchedPattern;
483 }
484};
485
486/// Manages a collection of cuts for a single logic node using priority cuts
487/// algorithm.
488///
489/// Each node in the combinational logic network can have multiple cuts
490/// representing different ways to group it with surrounding logic. The CutSet
491/// manages these cuts and selects the best one based on the optimization
492/// strategy (area or timing).
493///
494/// The priority cuts algorithm maintains a bounded set of the most promising
495/// cuts to avoid exponential explosion while ensuring good optimization
496/// results.
497class CutSet {
498private:
499 llvm::SmallVector<Cut *, 12> cuts; ///< Collection of cuts for this node
500 Cut *bestCut = nullptr;
501 bool isFrozen = false; ///< Whether cut set is finalized
502
503public:
504 /// Check if this cut set has a valid matched pattern.
505 bool isMatched() const { return bestCut; }
506
507 /// Get the cut associated with the best matched pattern.
508 Cut *getBestMatchedCut() const;
509
510 /// Finalize the cut set by removing duplicates and selecting the best
511 /// pattern.
512 void finalize(
513 const CutRewriterOptions &options,
514 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
515 const LogicNetwork &logicNetwork);
516
517 /// Get the number of cuts in this set.
518 unsigned size() const;
519
520 /// Add a new cut to this set using bump allocator.
521 /// NOTE: The cut set must not be frozen
522 void addCut(Cut *cut);
523
524 /// Get read-only access to all cuts in this set.
525 ArrayRef<Cut *> getCuts() const;
526};
527
528/// Configuration options for the cut-based rewriting algorithm.
529///
530/// These options control various aspects of the rewriting process including
531/// optimization strategy, resource limits, and algorithmic parameters.
533 /// Optimization strategy (area vs. timing).
535
536 /// Maximum number of inputs allowed for any cut.
537 /// Larger cuts provide more optimization opportunities but increase
538 /// computational complexity exponentially.
540
541 /// Maximum number of cuts to maintain per logic node.
542 /// The priority cuts algorithm keeps only the most promising cuts
543 /// to prevent exponential explosion.
545
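/// Whether to tolerate a root operation that has no matching pattern.
/// Presumably rewriting fails on such a root unless this is set (confirm
/// against the rewriter implementation).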
547 bool allowNoMatch = false;
548
549 /// Put arrival times to rewritten operations.
550 bool attachDebugTiming = false;
551
552 /// Run priority cuts enumeration and dump the cut sets.
553 bool testPriorityCuts = false;
554};
555
556//===----------------------------------------------------------------------===//
557// Cut Enumeration Engine
558//===----------------------------------------------------------------------===//
559
561 uint64_t numCutsCreated = 0;
562 uint64_t numCutSetsCreated = 0;
563 uint64_t numCutsRewritten = 0;
564};
565
566template <typename T>
568public:
571
572 template <typename... Args>
573 T *create(Args &&...args) {
575 return new (allocator.Allocate()) T(std::forward<Args>(args)...);
576 }
577
578 void DestroyAll() { allocator.DestroyAll(); }
579
580private:
581 llvm::SpecificBumpPtrAllocator<T> allocator;
583};
584/// Cut enumeration engine for combinational logic networks.
585///
586/// The CutEnumerator is responsible for generating cuts for each node in a
587/// combinational logic network. It uses a priority cuts algorithm to maintain a
588/// bounded set of promising cuts while avoiding exponential explosion.
589///
590/// The enumeration process works by:
591/// 1. Visiting nodes in topological order
592/// 2. For each node, combining cuts from its inputs
593/// 3. Matching generated cuts against available patterns
594/// 4. Maintaining only the most promising cuts per node
596public:
597 /// Constructor for cut enumerator.
599
600 /// Enumerate cuts for all nodes in the given module.
601 ///
602 /// This is the main entry point that orchestrates the cut enumeration
603 /// process. It visits all operations in the module and generates cuts
604 /// for combinational logic operations.
605 LogicalResult enumerateCuts(
606 Operation *topOp,
607 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
608 [](const Cut &) { return std::nullopt; });
609
610 /// Create a new cut set for an index.
611 /// The index must not already have a cut set.
612 CutSet *createNewCutSet(uint32_t index);
613
614 /// Get the cut set for a specific index.
615 /// If not found, it means no cuts have been generated for this value yet.
616 /// In that case return a trivial cut set.
617 const CutSet *getCutSet(uint32_t index);
618
619 /// Clear all cut sets and reset the enumerator.
620 void clear();
621
622 /// Record that one cut was successfully rewritten.
624
625 void dump() const;
626
627 /// Get cut sets (indexed by LogicNetwork index).
628 const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
629 return cutSets;
630 }
631
632 /// Get the processing order.
633 ArrayRef<uint32_t> getProcessingOrder() const { return processingOrder; }
634
635private:
636 /// Visit a combinational logic operation and generate cuts.
637 /// This handles the core cut enumeration logic for operations
638 /// like AND, OR, XOR, etc.
639 LogicalResult visitLogicOp(uint32_t nodeIndex);
640
641 /// Maps indices to their associated cut sets.
642 /// CutSets are allocated from the bump allocator.
643 llvm::DenseMap<uint32_t, CutSet *> cutSets;
644
645 /// Typed bump allocators for fast allocation with destructors.
648
649 /// Indices in processing order.
650 llvm::SmallVector<uint32_t> processingOrder;
651
652 /// Configuration options for cut enumeration.
654
655 /// Function to match cuts against available patterns.
656 /// Set during enumeration and used when finalizing cut sets.
657 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
658
659 /// Flat logic network representation used during enumeration/rewrite.
661
662 /// Statistics for cut enumeration (number of cuts allocated, etc.).
664
665public:
666 /// Get the logic network (read-only).
667 const LogicNetwork &getLogicNetwork() const { return logicNetwork; }
668
669 /// Get the logic network (mutable).
671
672 /// Get enumeration statistics.
673 const CutEnumeratorStats &getStats() const { return stats; }
674};
675
676/// Base class for cut rewriting patterns used in combinational logic
677/// optimization.
678///
679/// A CutRewritePattern represents a library component or optimization pattern
680/// that can replace cuts in the combinational logic network. Each pattern
681/// defines:
682/// - How to recognize matching cuts and compute area/delay metrics
683/// - How to transform/replace the matched cuts
684///
685/// Patterns can use truth table matching for efficient recognition or
686/// implement custom matching logic for more complex cases.
688 CutRewritePattern(mlir::MLIRContext *context) : context(context) {}
689 /// Virtual destructor for base class.
690 virtual ~CutRewritePattern() = default;
691
692 /// Check if a cut matches this pattern and compute area/delay metrics.
693 ///
694 /// This method is called to determine if a cut can be replaced by this
695 /// pattern. If the cut matches, it should return a MatchResult containing
696 /// the area and per-input delays for this specific cut.
697 ///
698 /// If useTruthTableMatcher() returns true, this method is only
699 /// called for cuts with matching truth tables.
700 virtual std::optional<MatchResult> match(CutEnumerator &enumerator,
701 const Cut &cut) const = 0;
702
703 /// Specify truth tables that this pattern can match.
704 ///
705 /// If this method returns true, the pattern matcher will use truth table
706 /// comparison for efficient pre-filtering. Only cuts with matching truth
707 /// tables will be passed to the match() method. If it returns false, the
708 /// pattern will be checked against all cuts regardless of their truth tables.
709 /// This is useful for patterns that match regardless of their truth tables,
710 /// such as LUT-based patterns.
711 virtual bool
712 useTruthTableMatcher(SmallVectorImpl<NPNClass> &matchingNPNClasses) const;
713
714 /// Return a new operation that replaces the matched cut.
715 ///
716 /// Unlike MLIR's RewritePattern framework which allows arbitrary in-place
717 /// modifications, this method creates a new operation to replace the matched
718 /// cut rather than modifying existing operations. This constraint exists
719 /// because the cut enumerator maintains references to operations throughout
720 /// the circuit, making it safe to only replace the root operation of each
721 /// cut while preserving all other operations unchanged.
722 virtual FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
723 CutEnumerator &enumerator,
724 const Cut &cut) const = 0;
725
726 /// Get the number of outputs this pattern produces.
727 virtual unsigned getNumOutputs() const = 0;
728
729 /// Get the name of this pattern. Used for debugging.
730 virtual StringRef getPatternName() const { return "<unnamed>"; }
731
/// Get the location for this pattern (optional).
733 virtual LocationAttr getLoc() const { return mlir::UnknownLoc::get(context); }
734
735 mlir::MLIRContext *getContext() const { return context; }
736
737private:
738 mlir::MLIRContext *context;
739};
740
741/// Manages a collection of rewriting patterns for combinational logic
742/// optimization.
743///
744/// This class organizes and provides efficient access to rewriting patterns
745/// used during cut-based optimization. It maintains:
746/// - A collection of all available patterns
747/// - Fast lookup tables for truth table-based matching
748/// - Separation of truth table vs. custom matching patterns
749///
750/// The pattern set is used by the CutRewriter to find suitable replacements
751/// for cuts in the combinational logic network.
753public:
754 /// Constructor that takes ownership of the provided patterns.
755 ///
756 /// During construction, patterns are analyzed and organized for efficient
757 /// lookup. Truth table matchers are indexed by their NPN canonical forms.
759 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns);
760
761private:
762 /// Owned collection of all rewriting patterns.
763 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns;
764
765 /// Fast lookup table mapping NPN canonical forms to matching patterns.
766 /// Each entry maps a truth table and input size to patterns that can handle
767 /// it.
768 DenseMap<std::pair<APInt, unsigned>,
769 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
771
772 /// Patterns that use custom matching logic instead of NPN lookup.
773 /// These patterns are checked against every cut.
774 SmallVector<const CutRewritePattern *, 4> nonNPNPatterns;
775
776 /// CutRewriter needs access to internal data structures for pattern matching.
777 friend class CutRewriter;
778};
779
780/// Main cut-based rewriting algorithm for combinational logic optimization.
781///
782/// The CutRewriter implements a cut-based rewriting algorithm that:
783/// 1. Enumerates cuts in the combinational logic network using a priority cuts
784/// algorithm
785/// 2. Matches cuts against available rewriting patterns
786/// 3. Selects optimal patterns based on area or timing objectives
787/// 4. Rewrites the circuit using the selected patterns
788///
789/// The algorithm processes the network in topological order, building up cut
790/// sets for each node and selecting the best implementation based on the
791/// specified optimization strategy.
792///
793/// Usage example:
794/// ```cpp
795/// CutRewriterOptions options;
796/// options.strategy = OptimizationStrategy::Area;
797/// options.maxCutInputSize = 4;
798/// options.maxCutSizePerRoot = 8;
799///
800/// CutRewritePatternSet patterns(std::move(optimizationPatterns));
801/// CutRewriter rewriter(module, options, patterns);
802///
803/// if (failed(rewriter.run())) {
804/// // Handle rewriting failure
805/// }
806/// ```
808public:
809 /// Constructor for the cut rewriter.
812
813 /// Execute the complete cut-based rewriting algorithm.
814 ///
815 /// This method orchestrates the entire rewriting process:
816 /// 1. Enumerate cuts for all nodes in the combinational logic
817 /// 2. Match cuts against available patterns
818 /// 3. Select optimal patterns based on strategy
819 /// 4. Rewrite the circuit with selected patterns
820 LogicalResult run(Operation *topOp);
821
823 return cutEnumerator.getStats();
824 }
825
826private:
827 /// Enumerate cuts for all nodes in the given module.
828 /// Note: This preserves module boundaries and does not perform
829 /// rewriting across the hierarchy.
830 LogicalResult enumerateCuts(Operation *topOp);
831
832 /// Find patterns that match a cut's truth table.
833 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
834 getMatchingPatternsFromTruthTable(const Cut &cut) const;
835
836 /// Match a cut against available patterns and compute arrival time.
837 std::optional<MatchedPattern> patternMatchCut(const Cut &cut);
838
839 /// Perform the actual circuit rewriting using selected patterns.
840 LogicalResult runBottomUpRewrite(Operation *topOp);
841
842 /// Configuration options
844
845 /// Available rewriting patterns
847
849};
850
851} // namespace synth
852} // namespace circt
853
854#endif // CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumeratorStats stats
Statistics for cut enumeration (number of cuts allocated, etc.).
void noteCutRewritten()
Record that one cut was successfully rewritten.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
const CutEnumeratorStats & getStats() const
Get enumeration statistics.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutEnumeratorStats & getStats() const
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut based on the given pattern NPN.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void computeTruthTable(const LogicNetwork &network)
Compute and cache the truth table for this cut using the LogicNetwork.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::SpecificBumpPtrAllocator< T > allocator
TrackedSpecificBumpPtrAllocator(uint64_t &allocationCount)
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Represents a boolean function as a truth table.
Definition TruthTable.h:39
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:104
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
llvm::PointerIntPair< Operation *, 3, Kind > opAndKind
Operation pointer and kind packed together.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Majority gate (3-input, mig::MajOp)
@ PrimaryInput
Primary input to the network.
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
Definition CutRewriter.h:75
uint32_t getIndex() const
Get the node index (without the inversion bit).
Definition CutRewriter.h:84
Signal operator!() const
Create an inverted version of this edge.
Definition CutRewriter.h:95
bool operator<(const Signal &other) const
Definition CutRewriter.h:99
bool operator!=(const Signal &other) const
Definition CutRewriter.h:98
Signal flipInversion() const
Definition CutRewriter.h:92
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
Definition CutRewriter.h:90
bool isInverted() const
Check if this edge is inverted.
Definition CutRewriter.h:87
bool operator==(const Signal &other) const
Definition CutRewriter.h:97
Signal(uint32_t index, bool inverted)
Definition CutRewriter.h:79
Signal(uint32_t raw)
Definition CutRewriter.h:81