16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/PointerIntPair.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
53FailureOr<BinaryTruthTable>
getTruthTable(ValueRange values, Block *block);
/// Construct a Signal that refers to network node `index`, optionally
/// inverted. The index is shifted left by one and the inversion flag is
/// stored in bit 0, i.e. data == (index << 1) | inverted.
/// NOTE(review): getRaw() returns uint32_t, so if `data` is 32 bits wide,
/// bit 31 of `index` is shifted out silently. Consider asserting that
/// index < (1u << 31) — TODO confirm the width of `data`.
78 Signal(uint32_t index,
bool inverted)
79 :
data((index << 1) | (inverted ? 1 : 0)) {}
134 llvm::ArrayRef<Signal> operands = {})
136 assert(operands.size() <= 3 &&
"Too many operands for LogicNetworkGate");
137 for (
size_t i = 0; i < operands.size(); ++i)
138 edges[i] = operands[i];
161 llvm_unreachable(
"Unknown gate kind");
221 uint32_t
getIndex(Value value)
const;
228 Value
getValue(uint32_t index)
const;
231 void getValues(ArrayRef<uint32_t> indices,
232 SmallVectorImpl<Value> &values)
const;
262 llvm::ArrayRef<Signal> operands = {});
266 llvm::ArrayRef<Signal> operands = {}) {
267 return addGate(op, kind, op->getResult(0), operands);
285 llvm::SmallVector<LogicNetworkGate>
gates;
347 SmallVector<DelayType, 1>
424 bool dominates(ArrayRef<uint32_t> otherInputs)
const;
467 SmallVectorImpl<unsigned> &permutedIndices)
const;
472 SmallVectorImpl<DelayType> &results)
const;
498 llvm::SmallVector<Cut *, 12>
cuts;
513 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut,
517 unsigned size()
const;
524 ArrayRef<Cut *>
getCuts()
const;
582 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut =
583 [](
const Cut &) {
return std::nullopt; });
600 const llvm::DenseMap<uint32_t, CutSet *> &
getCutSets()
const {
629 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut;
667 const Cut &cut)
const = 0;
688 virtual FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
690 const Cut &cut)
const = 0;
699 virtual LocationAttr
getLoc()
const {
return mlir::UnknownLoc::get(
context); }
725 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns);
729 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns;
734 DenseMap<std::pair<APInt, unsigned>,
735 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
786 LogicalResult
run(Operation *topOp);
795 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
assert(baseType &&"element must be base type")
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
LogicNetwork & getLogicNetwork()
Get the logic network (mutable).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
llvm::SpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
llvm::SpecificBumpPtrAllocator< CutSet > cutSetAllocator
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using the bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
bool isMatched() const
Check if this cut set has a valid matched pattern.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted inputs for this cut based on the given pattern NPN.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute the truth table using a fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
ArrayRef< const Cut * > getOperandCuts() const
Get operand cuts (for fast TT computation).
void computeTruthTable(const LogicNetwork &network)
Compute and cache the truth table for this cut using the LogicNetwork.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
Value getValue(Signal signal) const
Get the value for a given Signal (extracts index from Signal).
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
LogicNetworkGate & getGate(uint32_t index)
Get mutable reference to gate at index.
Signal getSignal(Value value, bool inverted) const
Get a Signal for a value.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, llvm::ArrayRef< Signal > operands={})
Add a gate using op->getResult(0) as the result value.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
static Signal getConstant0()
Get a Signal representing constant 0.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
size_t size() const
Get the total number of nodes in the network.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
static Signal getConstant1()
Get a Signal representing constant 1 (constant 0 inverted).
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get the location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
LogicNetworkGate(Operation *op, Kind kind, llvm::ArrayRef< Signal > operands={})
unsigned getNumFanins() const
Get the number of fanin edges based on kind.
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
llvm::PointerIntPair< Operation *, 3, Kind > opAndKind
Operation pointer and kind packed together.
bool isLogicGate() const
Check if this is a logic gate that can be part of a cut.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Majority gate (3-input, mig::MajOp)
@ PrimaryInput
Primary input to the network.
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.
Edge representation in the logic network.
uint32_t getIndex() const
Get the node index (without the inversion bit).
Signal operator!() const
Create an inverted version of this edge.
bool operator<(const Signal &other) const
bool operator!=(const Signal &other) const
Signal flipInversion() const
uint32_t getRaw() const
Get the raw data (index << 1 | inverted).
bool isInverted() const
Check if this edge is inverted.
bool operator==(const Signal &other) const
Signal(uint32_t index, bool inverted)