CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.h
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines a general cut-based rewriting framework for
10// combinational logic optimization. The framework uses NPN-equivalence matching
11// with area and delay metrics to rewrite cuts (subgraphs) in combinational
12// circuits with optimal patterns.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
18
20#include "circt/Support/LLVM.h"
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
31#include <memory>
32#include <optional>
33
34namespace circt {
35namespace synth {
36// Type for representing delays in the circuit. It's user's responsibility to
37// use consistent units, i.e., all delays should be in the same unit (usually
38// femtoseconds, but not limited to it).
39using DelayType = int64_t;
40
41/// Maximum number of inputs supported for truth table generation.
42/// This limit prevents excessive memory usage as truth table size grows
43/// exponentially with the number of inputs (2^n entries).
44static constexpr unsigned maxTruthTableInputs = 16;
45
46// This is a helper function to sort operations topologically in a logic
47// network. This is necessary for cut rewriting to ensure that operations are
48// processed in the correct order, respecting dependencies.
49LogicalResult topologicallySortLogicNetwork(mlir::Operation *op);
50
51// Get the truth table for a specific operation within a block.
52// Block must be a SSACFG or topologically sorted.
53FailureOr<BinaryTruthTable> getTruthTable(ValueRange values, Block *block);
54
55//===----------------------------------------------------------------------===//
56// Cut Data Structures
57//===----------------------------------------------------------------------===//
58
59// Forward declarations
61class CutRewriter;
62class CutEnumerator;
65class LogicNetwork;
66
67//===----------------------------------------------------------------------===//
68// Logic Network Data Structures (Flat IR for efficient cut enumeration)
69//===----------------------------------------------------------------------===//
70
/// Edge representation in the logic network.
/// Similar to mockturtle's signal, this encodes both a node index and an
/// inversion flag in a single 32-bit value: the LSB holds the inversion bit and
/// the upper 31 bits hold the node index.
struct Signal {
  uint32_t data = 0;

  Signal() = default;

  /// Build a signal from a node index and an inversion flag.
  /// The index is stored shifted left by one bit, so it must fit in 31 bits;
  /// a larger index would silently lose its most significant bit.
  Signal(uint32_t index, bool inverted)
      : data((index << 1) | (inverted ? 1u : 0u)) {
    assert(index <= (~uint32_t(0) >> 1) &&
           "Signal index does not fit in 31 bits");
  }

  /// Build a signal directly from its raw encoding (index << 1 | inverted).
  explicit Signal(uint32_t raw) : data(raw) {}

  /// Get the node index (without the inversion bit).
  uint32_t getIndex() const { return data >> 1; }

  /// Check if this edge is inverted.
  bool isInverted() const { return data & 1; }

  /// Get the raw data (index << 1 | inverted).
  uint32_t getRaw() const { return data; }

  /// Return the same node with the opposite polarity (equivalent to `!`).
  Signal flipInversion() const { return !*this; }

  /// Create an inverted version of this edge.
  Signal operator!() const { return Signal(data ^ 1); }

  bool operator==(const Signal &other) const { return data == other.data; }
  bool operator!=(const Signal &other) const { return data != other.data; }
  bool operator<(const Signal &other) const { return data < other.data; }
};
100
101/// Represents a single gate/node in the flat logic network.
102/// This structure is designed to be cache-friendly and supports up to 3 inputs
103/// (sufficient for AND, XOR, MAJ gates). For nodes with fewer inputs, unused
104/// edges have index 0 (constant 0 node).
105///
106/// Special indices:
107/// - Index 0: Constant 0
108/// - Index 1: Constant 1
109///
110/// It uses 8 bytes for operation pointer + enum, 12 bytes for edges = 20
111/// bytes per gate.
113 /// Kind of logic gate.
114 enum Kind : uint8_t {
115 Constant = 0, ///< Constant 0/1 node (index 0 = const0, index 1 = const1)
116 PrimaryInput = 1, ///< Primary input to the network
117 And2 = 2, ///< AND gate (2-input, aig::AndInverterOp)
118 Xor2 = 3, ///< XOR gate (2-input)
119 Maj3 = 4, ///< Majority gate (3-input, mig::MajOp)
120 Identity = 5 ///< Identity gate (used for 1-input inverter)
121 };
122
123 /// Operation pointer and kind packed together.
124 /// The kind is stored in the low bits of the pointer.
125 llvm::PointerIntPair<Operation *, 3, Kind> opAndKind;
126
127 /// Fanin edges (up to 3 inputs). For AND gates, only edges[0] and edges[1]
128 /// are used. For MAJ gates, all three are used. For PrimaryInput/Constant,
129 /// none are used. The inversion bit is encoded in each edge.
131
133 LogicNetworkGate(Operation *op, Kind kind,
134 llvm::ArrayRef<Signal> operands = {})
135 : opAndKind(op, kind), edges{} {
136 assert(operands.size() <= 3 && "Too many operands for LogicNetworkGate");
137 for (size_t i = 0; i < operands.size(); ++i)
138 edges[i] = operands[i];
139 }
140
141 /// Get the kind of this gate.
142 Kind getKind() const { return opAndKind.getInt(); }
143
144 /// Get the operation pointer (nullptr for constants).
145 Operation *getOperation() const { return opAndKind.getPointer(); }
146
147 /// Get the number of fanin edges based on kind.
148 unsigned getNumFanins() const {
149 switch (getKind()) {
150 case Constant:
151 case PrimaryInput:
152 return 0;
153 case And2:
154 case Xor2:
155 return 2;
156 case Maj3:
157 return 3;
158 case Identity:
159 return 1;
160 }
161 llvm_unreachable("Unknown gate kind");
162 }
163
164 /// Check if this is a logic gate that can be part of a cut.
165 bool isLogicGate() const {
166 Kind k = getKind();
167 return k == And2 || k == Xor2 || k == Maj3 || k == Identity;
168 }
169
170 /// Check if this should always be a cut input (PI or constant).
171 bool isAlwaysCutInput() const {
172 Kind k = getKind();
173 return k == PrimaryInput || k == Constant;
174 }
175};
176
177/// Flat logic network representation for efficient cut enumeration.
178///
179/// This class provides a mockturtle-style flat representation of the
180/// combinational logic network. Each value in the MLIR IR is assigned a unique
181/// index, and gates are stored in a contiguous vector for cache efficiency.
182///
183/// The network supports:
184/// - O(1) lookup of gate information by index
185/// - Compact representation with inversion encoded in edges
186/// - Efficient simulation and truth table computation
187///
188/// Special reserved indices:
189/// - Index 0: Constant 0
190/// - Index 1: Constant 1
192public:
193 /// Special constant indices.
194 static constexpr uint32_t kConstant0 = 0;
195 static constexpr uint32_t kConstant1 = 1;
196
197 ArrayRef<LogicNetworkGate> getGates() const { return gates; }
198
200 // Reserve index 0 for constant 0 and index 1 for constant 1
201 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
202 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
203 // indexToValue needs placeholders for constants
204 indexToValue.push_back(Value()); // const0
205 indexToValue.push_back(Value()); // const1
206 }
207
208 /// Get a LogicEdge representing constant 0.
209 static Signal getConstant0() { return Signal(kConstant0, false); }
210
211 /// Get a LogicEdge representing constant 1 (constant 0 inverted).
212 static Signal getConstant1() { return Signal(kConstant0, true); }
213
214 /// Get or create an index for a value.
215 /// If the value doesn't have an index yet, assigns one and returns the index.
216 uint32_t getOrCreateIndex(Value value);
217
218 /// Get the raw index for a value. Asserts if value is not found.
219 /// Note: This returns only the index, not a Signal with inversion info.
220 /// Use hasIndex() to check existence first, or use getOrCreateIndex().
221 uint32_t getIndex(Value value) const;
222
223 /// Check if a value has been indexed.
224 bool hasIndex(Value value) const;
225
226 /// Get the value for a given raw index. Asserts if index is out of bounds.
227 /// Returns null Value for constant indices (0 and 1).
228 Value getValue(uint32_t index) const;
229
230 /// Fill values for the given raw indices.
231 void getValues(ArrayRef<uint32_t> indices,
232 SmallVectorImpl<Value> &values) const;
233
234 /// Get a Signal for a value.
235 /// Asserts if value not found - use hasIndex() first if unsure.
236 Signal getSignal(Value value, bool inverted) const {
237 return Signal(getIndex(value), inverted);
238 }
239
240 /// Get or create a Signal for a value.
241 Signal getOrCreateSignal(Value value, bool inverted) {
242 return Signal(getOrCreateIndex(value), inverted);
243 }
244
245 /// Get the value for a given Signal (extracts index from Signal).
246 Value getValue(Signal signal) const { return getValue(signal.getIndex()); }
247
248 /// Get the gate at a given index.
249 const LogicNetworkGate &getGate(uint32_t index) const { return gates[index]; }
250
251 /// Get mutable reference to gate at index.
252 LogicNetworkGate &getGate(uint32_t index) { return gates[index]; }
253
254 /// Get the total number of nodes in the network.
255 size_t size() const { return gates.size(); }
256
257 /// Add a primary input to the network.
258 uint32_t addPrimaryInput(Value value);
259
260 /// Add a gate with explicit result value and operand signals.
261 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result,
262 llvm::ArrayRef<Signal> operands = {});
263
264 /// Add a gate using op->getResult(0) as the result value.
265 uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind,
266 llvm::ArrayRef<Signal> operands = {}) {
267 return addGate(op, kind, op->getResult(0), operands);
268 }
269
270 /// Build the logic network from a region/block in topological order.
271 /// Returns failure if the IR is not in a valid form.
272 LogicalResult buildFromBlock(Block *block);
273
274 /// Clear the network and reset to initial state.
275 void clear();
276
277private:
278 /// Map from MLIR Value to network index.
279 llvm::DenseMap<Value, uint32_t> valueToIndex;
280
281 /// Map from network index to MLIR Value.
282 llvm::SmallVector<Value> indexToValue;
283
284 /// Vector of all gates in the network.
285 llvm::SmallVector<LogicNetworkGate> gates;
286};
287
288/// Result of matching a cut against a pattern.
289///
290/// This structure contains the area and per-input delay information
291/// computed during pattern matching.
292///
293/// The delays can be stored in two ways:
294/// 1. As a reference to static/cached data (e.g., tech library delays)
295/// - Use setDelayRef() for zero-cost reference (no allocation)
296/// 2. As owned dynamic data (e.g., computed SOP delays)
297/// - Use setOwnedDelays() to transfer ownership
298///
300 /// Area cost of implementing this cut with the pattern.
301 double area = 0.0;
302
303 /// Default constructor.
304 MatchResult() = default;
305
306 /// Constructor with area and delays (by reference).
307 MatchResult(double area, ArrayRef<DelayType> delays)
308 : area(area), borrowedDelays(delays) {}
309
310 /// Set delays by reference (zero-cost for static/cached delays).
311 /// The caller must ensure the referenced data remains valid.
312 void setDelayRef(ArrayRef<DelayType> delays) { borrowedDelays = delays; }
313
314 /// Set delays by transferring ownership (for dynamically computed delays).
315 /// This moves the data into internal storage.
316 void setOwnedDelays(SmallVector<DelayType, 6> delays) {
317 ownedDelays.emplace(std::move(delays));
318 borrowedDelays = {};
319 }
320
321 /// Get all delays as an ArrayRef.
322 ArrayRef<DelayType> getDelays() const {
323 return ownedDelays.has_value() ? ArrayRef<DelayType>(*ownedDelays)
325 }
326
327private:
328 /// Borrowed delays (used when ownedDelays is empty).
329 /// Points to external data provided via setDelayRef().
330 ArrayRef<DelayType> borrowedDelays;
331
332 /// Owned delays (used when present).
333 /// Only allocated when setOwnedDelays() is called. When empty (std::nullopt),
334 /// moving this MatchResult avoids constructing/moving the SmallVector,
335 /// achieving zero-cost abstraction for the common case (borrowed delays).
336 std::optional<SmallVector<DelayType, 6>> ownedDelays;
337};
338
339/// Represents a cut that has been successfully matched to a rewriting pattern.
340///
341/// This class encapsulates the result of matching a cut against a rewriting
342/// pattern during optimization. It stores the matched pattern, the
343/// cut that was matched, and timing information needed for optimization.
345private:
346 const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
347 SmallVector<DelayType, 1>
348 arrivalTimes; ///< Arrival times of outputs from this pattern
349 double area = 0.0; ///< Area cost of this pattern
350
351public:
352 /// Default constructor creates an invalid matched pattern.
353 MatchedPattern() = default;
354
355 /// Constructor for a valid matched pattern.
357 SmallVector<DelayType, 1> arrivalTimes, double area)
358 : pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
359
360 /// Get the arrival time of signals through this pattern.
361 DelayType getArrivalTime(unsigned outputIndex) const;
362 ArrayRef<DelayType> getArrivalTimes() const;
363
364 /// Get the library pattern that was matched.
365 const CutRewritePattern *getPattern() const;
366
367 /// Get the area cost of using this pattern.
368 double getArea() const;
369};
370
371/// Represents a cut in the combinational logic network.
372///
373/// A cut is a subset of nodes in the combinational logic that forms a complete
374/// subgraph with a single output. It represents a portion of the circuit that
375/// can potentially be replaced with a single library gate or pattern.
376///
377/// The cut stores:
378/// - inputs: the cut boundary, as LogicNetwork indices
379/// - rootIndex: the LogicNetwork node driving the cut output (0 = trivial cut)
380/// - cached truth table / NPN class, the operand cuts it was built from, and
/// the pattern selected for it (if any)
381///
382/// Cuts are used in combinational logic optimization to identify regions that
383/// can be optimized and replaced with more efficient implementations.
384class Cut {
385 /// Cached truth table for this cut. Filled by computeTruthTable(),
386 /// computeTruthTableFromOperands() or setTruthTable(); std::nullopt until then.
387 mutable std::optional<BinaryTruthTable> truthTable;
388
389 /// Cached NPN canonical form for this cut.
390 /// Computed lazily from the truth table when first accessed.
391 mutable std::optional<NPNClass> npnClass;
392
/// Pattern selected for this cut, if any (read via getMatchedPattern()).
393 std::optional<MatchedPattern> matchedPattern;
394
395 /// Root index in LogicNetwork (0 indicates no root for a trivial cut).
396 /// The root node produces the output of the cut.
397 uint32_t rootIndex = 0;
398
399 /// Operand cuts used to create this cut (for lazy TT computation).
400 /// Stored to enable fast incremental truth table computation after
401 /// duplicate removal. Using raw pointers is safe since cuts are allocated
402 /// via bump allocator and live for the duration of enumeration.
403 llvm::SmallVector<const Cut *, 3> operandCuts;
404
405public:
406 /// External inputs to this cut (cut boundary).
407 /// Stored as LogicNetwork indices for efficient operations.
408 llvm::SmallVector<uint32_t, 6> inputs;
409
410 /// Check if this cut represents a trivial cut.
411 /// A trivial cut has no root operation and exactly one input.
412 bool isTrivialCut() const;
413
414 /// Get the root index in the LogicNetwork.
415 uint32_t getRootIndex() const { return rootIndex; }
416
417 /// Set the root index of this cut.
418 void setRootIndex(uint32_t idx) { rootIndex = idx; }
419
420 /// Check if this cut dominates another cut.
421 bool dominates(const Cut &other) const;
422
423 /// Check if this cut dominates another sorted input set.
424 bool dominates(ArrayRef<uint32_t> otherInputs) const;
425
/// Print this cut to `os` for debugging. `network` is presumably used to
/// resolve the stored indices - confirm against the implementation.
426 void dump(llvm::raw_ostream &os, const LogicNetwork &network) const;
427
428 /// Get the number of inputs to this cut.
429 unsigned getInputSize() const;
430
431 /// Get the number of outputs from root operation.
432 unsigned getOutputSize(const LogicNetwork &network) const;
433
434 /// Get the cached truth table for this cut (std::nullopt if not computed).
435 /// The truth table represents the boolean function computed by this cut.
436 const std::optional<BinaryTruthTable> &getTruthTable() const {
437 return truthTable;
438 }
439
440 /// Compute and cache the truth table for this cut using the LogicNetwork.
441 void computeTruthTable(const LogicNetwork &network);
442
443 /// Compute truth table using fast incremental method from operand cuts.
444 /// This is much faster than simulation-based computation.
445 /// Requires that operand cuts have already been set via setOperandCuts.
446 void computeTruthTableFromOperands(const LogicNetwork &network);
447
448 /// Set the truth table directly (used for incremental computation).
449 void setTruthTable(BinaryTruthTable tt) { truthTable.emplace(std::move(tt)); }
450
451 /// Set operand cuts for lazy truth table computation.
452 void setOperandCuts(ArrayRef<const Cut *> cuts) {
453 operandCuts.assign(cuts.begin(), cuts.end());
454 }
455
456 /// Get operand cuts (for fast TT computation).
457 ArrayRef<const Cut *> getOperandCuts() const { return operandCuts; }
458
459 /// Get the NPN canonical form for this cut.
460 /// This is used for efficient pattern matching against library components.
461 const NPNClass &getNPNClass() const;
462
463 /// Get the permuted input order for this cut based on the given pattern NPN.
464 /// Writes indices into the `inputs` vector to `permutedIndices`.
465 void
466 getPermutatedInputIndices(const NPNClass &patternNPN,
467 SmallVectorImpl<unsigned> &permutedIndices) const;
468
469 /// Get arrival times for each input of this cut.
470 /// Returns failure if any input doesn't have a valid matched pattern.
471 LogicalResult getInputArrivalTimes(CutEnumerator &enumerator,
472 SmallVectorImpl<DelayType> &results) const;
473
474 /// Matched pattern for this cut.
/// NOTE(review): no declaration follows this comment in this copy of the
/// header; a setter for `matchedPattern` appears to be missing. Restore it
/// from upstream.
478
479 /// Get the matched pattern for this cut.
480 const std::optional<MatchedPattern> &getMatchedPattern() const {
481 return matchedPattern;
482 }
483};
484
485/// Manages a collection of cuts for a single logic node using priority cuts
486/// algorithm.
487///
488/// Each node in the combinational logic network can have multiple cuts
489/// representing different ways to group it with surrounding logic. The CutSet
490/// manages these cuts and selects the best one based on the optimization
491/// strategy (area or timing).
492///
493/// The priority cuts algorithm maintains a bounded set of the most promising
494/// cuts to avoid exponential explosion while ensuring good optimization
495/// results.
496class CutSet {
497private:
498 llvm::SmallVector<Cut *, 12> cuts; ///< Collection of cuts for this node
499 Cut *bestCut = nullptr;
500 bool isFrozen = false; ///< Whether cut set is finalized
501
502public:
503 /// Check if this cut set has a valid matched pattern.
504 bool isMatched() const { return bestCut; }
505
506 /// Get the cut associated with the best matched pattern.
507 Cut *getBestMatchedCut() const;
508
509 /// Finalize the cut set by removing duplicates and selecting the best
510 /// pattern.
511 void finalize(
512 const CutRewriterOptions &options,
513 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
514 const LogicNetwork &logicNetwork);
515
516 /// Get the number of cuts in this set.
517 unsigned size() const;
518
519 /// Add a new cut to this set using bump allocator.
520 /// NOTE: The cut set must not be frozen
521 void addCut(Cut *cut);
522
523 /// Get read-only access to all cuts in this set.
524 ArrayRef<Cut *> getCuts() const;
525};
526
527/// Configuration options for the cut-based rewriting algorithm.
528///
529/// These options control various aspects of the rewriting process including
530/// optimization strategy, resource limits, and algorithmic parameters.
/// NOTE(review): this copy of the header is missing the
/// `struct CutRewriterOptions {` line and the declarations of the `strategy`,
/// `maxCutInputSize` and `maxCutSizePerRoot` fields documented below. Restore
/// them (with their types and defaults) from upstream.
532 /// Optimization strategy (area vs. timing).
534
535 /// Maximum number of inputs allowed for any cut.
536 /// Larger cuts provide more optimization opportunities but increase
537 /// computational complexity exponentially.
539
540 /// Maximum number of cuts to maintain per logic node.
541 /// The priority cuts algorithm keeps only the most promising cuts
542 /// to prevent exponential explosion.
544
545 /// Whether a root operation with no matching pattern is tolerated.
/// NOTE(review): presumably such a root is reported as a failure when this is
/// false (the default). Confirm against the rewriter implementation.
546 bool allowNoMatch = false;
547
548 /// Attach arrival times to rewritten operations (debugging aid).
549 bool attachDebugTiming = false;
550
551 /// Run priority cuts enumeration and dump the cut sets.
552 bool testPriorityCuts = false;
553};
554
555//===----------------------------------------------------------------------===//
556// Cut Enumeration Engine
557//===----------------------------------------------------------------------===//
558
559/// Cut enumeration engine for combinational logic networks.
560///
561/// The CutEnumerator is responsible for generating cuts for each node in a
562/// combinational logic network. It uses a priority cuts algorithm to maintain a
563/// bounded set of promising cuts while avoiding exponential explosion.
564///
565/// The enumeration process works by:
566/// 1. Visiting nodes in topological order
567/// 2. For each node, combining cuts from its inputs
568/// 3. Matching generated cuts against available patterns
569/// 4. Maintaining only the most promising cuts per node
///
/// NOTE(review): this copy of the header is missing the
/// `class CutEnumerator {` line and several declarations, each flagged below.
/// Restore them from upstream.
571public:
572 /// Constructor for cut enumerator.
/// NOTE(review): constructor declaration missing here.
574
575 /// Enumerate cuts for all nodes in the given operation (`topOp`).
576 ///
577 /// This is the main entry point that orchestrates the cut enumeration
578 /// process. It visits all operations in the module and generates cuts
579 /// for combinational logic operations.
580 LogicalResult enumerateCuts(
581 Operation *topOp,
582 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
583 [](const Cut &) { return std::nullopt; });
584
585 /// Create a new cut set for an index.
586 /// The index must not already have a cut set.
587 CutSet *createNewCutSet(uint32_t index);
588
589 /// Get the cut set for a specific index.
590 /// If not found, it means no cuts have been generated for this value yet.
591 /// In that case return a trivial cut set.
592 const CutSet *getCutSet(uint32_t index);
593
594 /// Clear all cut sets and reset the enumerator.
595 void clear();
596
/// Dump the enumerator state for debugging (format defined by the
/// implementation).
597 void dump() const;
598
599 /// Get cut sets (indexed by LogicNetwork index).
600 const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
601 return cutSets;
602 }
603
604 /// Get the processing order.
605 ArrayRef<uint32_t> getProcessingOrder() const { return processingOrder; }
606
607private:
608 /// Visit a combinational logic operation and generate cuts.
609 /// This handles the core cut enumeration logic for the gate kinds in
610 /// LogicNetworkGate::Kind (AND, XOR, MAJ and identity gates).
611 LogicalResult visitLogicOp(uint32_t nodeIndex);
612
613 /// Maps indices to their associated cut sets.
614 /// CutSets are allocated from the bump allocator.
615 llvm::DenseMap<uint32_t, CutSet *> cutSets;
616
617 /// Typed bump allocators for fast allocation with destructors.
618 llvm::SpecificBumpPtrAllocator<Cut> cutAllocator;
619 llvm::SpecificBumpPtrAllocator<CutSet> cutSetAllocator;
620
621 /// Indices in processing order.
622 llvm::SmallVector<uint32_t> processingOrder;
623
624 /// Configuration options for cut enumeration.
/// NOTE(review): member declaration missing here.
626
627 /// Function to match cuts against available patterns.
628 /// Set during enumeration and used when finalizing cut sets.
/// llvm::function_ref does not own its callee, and the default argument of
/// enumerateCuts() is a temporary lambda. This reference must therefore not
/// be invoked after the enumerateCuts() call that set it has returned.
629 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
630
631 /// Flat logic network representation used during enumeration/rewrite.
/// NOTE(review): member declaration missing here.
633
634public:
635 /// Get the logic network (read-only).
636 const LogicNetwork &getLogicNetwork() const { return logicNetwork; }
637
638 /// Get the logic network (mutable).
/// NOTE(review): the mutable getLogicNetwork() overload is missing here.
640};
641
642/// Base class for cut rewriting patterns used in combinational logic
643/// optimization.
644///
645/// A CutRewritePattern represents a library component or optimization pattern
646/// that can replace cuts in the combinational logic network. Each pattern
647/// defines:
648/// - How to recognize matching cuts and compute area/delay metrics
649/// - How to transform/replace the matched cuts
650///
651/// Patterns can use truth table matching for efficient recognition or
652/// implement custom matching logic for more complex cases.
654 CutRewritePattern(mlir::MLIRContext *context) : context(context) {}
655 /// Virtual destructor for base class.
656 virtual ~CutRewritePattern() = default;
657
658 /// Check if a cut matches this pattern and compute area/delay metrics.
659 ///
660 /// This method is called to determine if a cut can be replaced by this
661 /// pattern. If the cut matches, it should return a MatchResult containing
662 /// the area and per-input delays for this specific cut.
663 ///
664 /// If useTruthTableMatcher() returns true, this method is only
665 /// called for cuts with matching truth tables.
666 virtual std::optional<MatchResult> match(CutEnumerator &enumerator,
667 const Cut &cut) const = 0;
668
669 /// Specify truth tables that this pattern can match.
670 ///
671 /// If this method returns true, the pattern matcher will use truth table
672 /// comparison for efficient pre-filtering. Only cuts with matching truth
673 /// tables will be passed to the match() method. If it returns false, the
674 /// pattern will be checked against all cuts regardless of their truth tables.
675 /// This is useful for patterns that match regardless of their truth tables,
676 /// such as LUT-based patterns.
677 virtual bool
678 useTruthTableMatcher(SmallVectorImpl<NPNClass> &matchingNPNClasses) const;
679
680 /// Return a new operation that replaces the matched cut.
681 ///
682 /// Unlike MLIR's RewritePattern framework which allows arbitrary in-place
683 /// modifications, this method creates a new operation to replace the matched
684 /// cut rather than modifying existing operations. This constraint exists
685 /// because the cut enumerator maintains references to operations throughout
686 /// the circuit, making it safe to only replace the root operation of each
687 /// cut while preserving all other operations unchanged.
688 virtual FailureOr<Operation *> rewrite(mlir::OpBuilder &builder,
689 CutEnumerator &enumerator,
690 const Cut &cut) const = 0;
691
692 /// Get the number of outputs this pattern produces.
693 virtual unsigned getNumOutputs() const = 0;
694
695 /// Get the name of this pattern. Used for debugging.
696 virtual StringRef getPatternName() const { return "<unnamed>"; }
697
698 /// Get location for this pattern(optional).
699 virtual LocationAttr getLoc() const { return mlir::UnknownLoc::get(context); }
700
701 mlir::MLIRContext *getContext() const { return context; }
702
703private:
704 mlir::MLIRContext *context;
705};
706
707/// Manages a collection of rewriting patterns for combinational logic
708/// optimization.
709///
710/// This class organizes and provides efficient access to rewriting patterns
711/// used during cut-based optimization. It maintains:
712/// - A collection of all available patterns
713/// - Fast lookup tables for truth table-based matching
714/// - Separation of truth table vs. custom matching patterns
715///
716/// The pattern set is used by the CutRewriter to find suitable replacements
717/// for cuts in the combinational logic network.
719public:
720 /// Constructor that takes ownership of the provided patterns.
721 ///
722 /// During construction, patterns are analyzed and organized for efficient
723 /// lookup. Truth table matchers are indexed by their NPN canonical forms.
725 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns);
726
727private:
728 /// Owned collection of all rewriting patterns.
729 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns;
730
731 /// Fast lookup table mapping NPN canonical forms to matching patterns.
732 /// Each entry maps a truth table and input size to patterns that can handle
733 /// it.
734 DenseMap<std::pair<APInt, unsigned>,
735 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
737
738 /// Patterns that use custom matching logic instead of NPN lookup.
739 /// These patterns are checked against every cut.
740 SmallVector<const CutRewritePattern *, 4> nonNPNPatterns;
741
742 /// CutRewriter needs access to internal data structures for pattern matching.
743 friend class CutRewriter;
744};
745
746/// Main cut-based rewriting algorithm for combinational logic optimization.
747///
748/// The CutRewriter implements a cut-based rewriting algorithm that:
749/// 1. Enumerates cuts in the combinational logic network using a priority cuts
750/// algorithm
751/// 2. Matches cuts against available rewriting patterns
752/// 3. Selects optimal patterns based on area or timing objectives
753/// 4. Rewrites the circuit using the selected patterns
754///
755/// The algorithm processes the network in topological order, building up cut
756/// sets for each node and selecting the best implementation based on the
757/// specified optimization strategy.
758///
759/// Usage example:
760/// ```cpp
761/// CutRewriterOptions options;
762/// options.strategy = OptimizationStrategy::Area;
763/// options.maxCutInputSize = 4;
764/// options.maxCutSizePerRoot = 8;
765///
766/// CutRewritePatternSet patterns(std::move(optimizationPatterns));
767/// CutRewriter rewriter(module, options, patterns);
768///
769/// if (failed(rewriter.run(module))) {
770/// // Handle rewriting failure
771/// }
772/// ```
/// NOTE(review): run() takes the top-level operation, so the example passes
/// it explicitly. The constructor arguments shown above may be stale as well
/// (the constructor declaration is missing from this copy); confirm.
/// NOTE(review): the `class CutRewriter {` line is missing from this copy of
/// the header. Restore it from upstream.
774public:
775 /// Constructor for the cut rewriter.
/// NOTE(review): constructor declaration missing here.
778
779 /// Execute the complete cut-based rewriting algorithm on `topOp`.
780 ///
781 /// This method orchestrates the entire rewriting process:
782 /// 1. Enumerate cuts for all nodes in the combinational logic
783 /// 2. Match cuts against available patterns
784 /// 3. Select optimal patterns based on strategy
785 /// 4. Rewrite the circuit with selected patterns
786 LogicalResult run(Operation *topOp);
787
788private:
789 /// Enumerate cuts for all nodes in the given module.
790 /// Note: This preserves module boundaries and does not perform
791 /// rewriting across the hierarchy.
792 LogicalResult enumerateCuts(Operation *topOp);
793
794 /// Find patterns that match a cut's truth table.
795 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
796 getMatchingPatternsFromTruthTable(const Cut &cut) const;
797
798 /// Match a cut against available patterns and compute arrival time.
799 std::optional<MatchedPattern> patternMatchCut(const Cut &cut);
800
801 /// Perform the actual circuit rewriting using selected patterns.
802 LogicalResult runBottomUpRewrite(Operation *topOp);
803
804 /// Configuration options
/// NOTE(review): member declaration missing here.
806
807 /// Available rewriting patterns
/// NOTE(review): member declaration missing here.
809
811};
812
813} // namespace synth
814} // namespace circt
815
816#endif // CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
llvm::SpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
llvm::SpecificBumpPtrAllocator< CutSet > cutSetAllocator
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut based on the given pattern's NPN class.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void computeTruthTable(const LogicNetwork &network)
Compute and cache the truth table for this cut using the LogicNetwork.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:39
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:44
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
Represents a boolean function as a truth table.
Definition TruthTable.h:39
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:104
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when false, rewriting fails if any root operation has no matching pattern.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
llvm::PointerIntPair< Operation *, 3, Kind > opAndKind
Operation pointer and kind packed together.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Majority gate (3-input, mig::MajOp)
@ PrimaryInput
Primary input to the network.
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays holds no value).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
Definition CutRewriter.h:74
uint32_t getIndex() const
Get the node index (without the inversion bit).
Definition CutRewriter.h:83
Signal operator!() const
Create an inverted version of this edge.
Definition CutRewriter.h:94
bool operator<(const Signal &other) const
Definition CutRewriter.h:98
bool operator!=(const Signal &other) const
Definition CutRewriter.h:97
Signal flipInversion() const
Definition CutRewriter.h:91
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
Definition CutRewriter.h:89
bool isInverted() const
Check if this edge is inverted.
Definition CutRewriter.h:86
bool operator==(const Signal &other) const
Definition CutRewriter.h:96
Signal(uint32_t index, bool inverted)
Definition CutRewriter.h:78
Signal(uint32_t raw)
Definition CutRewriter.h:80