16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/LogicalResult.h"
27#include "llvm/Support/raw_ostream.h"
50FailureOr<BinaryTruthTable>
getTruthTable(ValueRange values, Block *block);
122 SmallVector<DelayType, 1>
174 llvm::SmallSetVector<mlir::Value, 4>
inputs;
186 mlir::Operation *
getRoot()
const;
188 void dump(llvm::raw_ostream &os)
const;
214 SmallVectorImpl<Value> &permutedInputs)
const;
219 SmallVectorImpl<DelayType> &results)
const;
245 llvm::SmallVector<Cut, 4>
cuts;
266 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut);
269 unsigned size()
const;
334 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut =
335 [](
const Cut &) {
return std::nullopt; });
348 llvm::MapVector<Value, std::unique_ptr<CutSet>>
takeVector();
354 const llvm::MapVector<Value, std::unique_ptr<CutSet>> &
getCutSets()
const {
360 LogicalResult
visit(Operation *op);
368 llvm::MapVector<Value, std::unique_ptr<CutSet>>
cutSets;
375 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut;
403 const Cut &cut)
const = 0;
424 virtual FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
426 const Cut &cut)
const = 0;
435 virtual LocationAttr
getLoc()
const {
return mlir::UnknownLoc::get(
context); }
461 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns);
465 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns;
470 DenseMap<std::pair<APInt, unsigned>,
471 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
522 LogicalResult
run(Operation *topOp);
531 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visit(Operation *op)
Visit a single operation and generate cuts for it.
const CutRewriterOptions & options
Configuration options for cut enumeration.
llvm::MapVector< Value, std::unique_ptr< CutSet > > cutSets
Maps values to their associated cut sets.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
const llvm::MapVector< Value, std::unique_ptr< CutSet > > & getCutSets() const
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
LogicalResult visitLogicOp(Operation *logicOp)
Visit a combinational logic operation and generate cuts.
llvm::MapVector< Value, std::unique_ptr< CutSet > > takeVector()
Move ownership of all cut sets to caller.
CutSet * createNewCutSet(Value value)
Create a new cut set for a value.
const CutSet * getCutSet(Value value)
Get the cut set for a specific value.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
llvm::SmallVector< Cut, 4 > cuts
Collection of cuts for this node.
void addCut(Cut cut)
Add a new cut to this set.
unsigned size() const
Get the number of cuts in this set.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut)
Finalize the cut set by removing duplicates and selecting the best pattern.
bool isMatched() const
Check if this cut set has a valid matched pattern.
bool isFrozen
Whether cut set is finalized.
ArrayRef< Cut > getCuts() const
Get read-only access to all cuts in this set.
Represents a cut in the combinational logic network.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
llvm::SmallSetVector< mlir::Operation *, 4 > operations
Operations contained within this cut.
unsigned getOutputSize() const
Get the number of outputs from the root operation.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permuted inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
llvm::SmallSetVector< mlir::Value, 4 > inputs
External inputs to this cut (cut boundary).
void dump(llvm::raw_ostream &os) const
const BinaryTruthTable & getTruthTable() const
Get the truth table for this cut.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
Cut mergeWith(const Cut &other, Operation *root) const
Merge this cut with another cut to form a new cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
unsigned getCutSize() const
Get the number of operations in this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Cut reRoot(Operation *root) const
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes, double area)
Constructor for a valid matched pattern.
DelayType getWorstOutputArrivalTime() const
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
CutRewritePattern(mlir::MLIRContext *context)
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual LocationAttr getLoc() const
Get location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Controls whether a root operation that has no matching pattern is tolerated or causes a failure.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Result of matching a cut against a pattern.
void setDelayRef(ArrayRef< DelayType > delays)
Set delays by reference (zero-cost for static/cached delays).
MatchResult()=default
Default constructor.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.
ArrayRef< DelayType > borrowedDelays
Borrowed delays (used when ownedDelays is empty).
std::optional< SmallVector< DelayType, 6 > > ownedDelays
Owned delays (used when present).
MatchResult(double area, ArrayRef< DelayType > delays)
Constructor with area and delays (by reference).
ArrayRef< DelayType > getDelays() const
Get all delays as an ArrayRef.