16#ifndef CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
17#define CIRCT_DIALECT_SYNTH_TRANSFORMS_CUT_REWRITER_H
22#include "mlir/IR/Operation.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/LogicalResult.h"
27#include "llvm/Support/raw_ostream.h"
50FailureOr<BinaryTruthTable>
getTruthTable(ValueRange values, Block *block);
70 SmallVector<DelayType, 1>
124 llvm::SmallSetVector<mlir::Value, 4>
inputs;
136 mlir::Operation *
getRoot()
const;
138 void dump(llvm::raw_ostream &os)
const;
164 SmallVectorImpl<Value> &permutedInputs)
const;
190 llvm::SmallVector<Cut, 4>
cuts;
211 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut);
214 unsigned size()
const;
279 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut =
280 [](
const Cut &) {
return std::nullopt; });
293 llvm::MapVector<Value, std::unique_ptr<CutSet>>
takeVector();
302 LogicalResult
visit(Operation *op);
310 llvm::MapVector<Value, std::unique_ptr<CutSet>>
cutSets;
317 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)>
matchCut;
363 virtual FailureOr<Operation *>
rewrite(mlir::OpBuilder &builder,
373 unsigned outputIndex)
const = 0;
382 virtual LocationAttr
getLoc()
const {
return mlir::UnknownLoc::get(
context); }
408 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns);
412 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns;
417 DenseMap<std::pair<APInt, unsigned>,
418 SmallVector<std::pair<NPNClass, const CutRewritePattern *>>>
469 LogicalResult
run(Operation *topOp);
478 ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visit(Operation *op)
Visit a single operation and generate cuts for it.
const CutRewriterOptions & options
Configuration options for cut enumeration.
llvm::MapVector< Value, std::unique_ptr< CutSet > > cutSets
Maps values to their associated cut sets.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
LogicalResult visitLogicOp(Operation *logicOp)
Visit a combinational logic operation and generate cuts.
llvm::MapVector< Value, std::unique_ptr< CutSet > > takeVector()
Move ownership of all cut sets to caller.
CutSet * createNewCutSet(Value value)
Create a new cut set for a value.
const CutSet * getCutSet(Value value)
Get the cut set for a specific value.
Manages a collection of rewriting patterns for combinational logic optimization.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
SmallVector< const CutRewritePattern *, 4 > nonTruthTablePatterns
Patterns that use custom matching logic instead of truth tables.
Main cut-based rewriting algorithm for combinational logic optimization.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
const CutRewritePatternSet & patterns
Available rewriting patterns.
CutEnumerator cutEnumerator
CutRewriter(const CutRewriterOptions &options, CutRewritePatternSet &patterns)
Constructor for the cut rewriter.
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
llvm::SmallVector< Cut, 4 > cuts
Collection of cuts for this node.
void addCut(Cut cut)
Add a new cut to this set.
unsigned size() const
Get the number of cuts in this set.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut)
Finalize the cut set by removing duplicates and selecting the best pattern.
bool isMatched() const
Check if this cut set has a valid matched pattern.
bool isFrozen
Whether cut set is finalized.
ArrayRef< Cut > getCuts() const
Get read-only access to all cuts in this set.
Represents a cut in the combinational logic network.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
llvm::SmallSetVector< mlir::Operation *, 4 > operations
Operations contained within this cut.
unsigned getOutputSize() const
Get the number of outputs from root operation.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permuted inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
llvm::SmallSetVector< mlir::Value, 4 > inputs
External inputs to this cut (cut boundary).
void dump(llvm::raw_ostream &os) const
const BinaryTruthTable & getTruthTable() const
Get the truth table for this cut.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
Cut mergeWith(const Cut &other, Operation *root) const
Merge this cut with another cut to form a new cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
unsigned getCutSize() const
Get the number of operations in this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Cut reRoot(Operation *root) const
Represents a cut that has been successfully matched to a rewriting pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
MatchedPattern(const CutRewritePattern *pattern, SmallVector< DelayType, 1 > arrivalTimes)
Constructor for a valid matched pattern.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const
Get the delay between specific input and output pins.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
DelayType getWorstOutputArrivalTime() const
double getArea() const
Get the area cost of using this pattern.
MatchedPattern()=default
Default constructor creates an invalid matched pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual ~CutRewritePattern()=default
Virtual destructor for base class.
virtual DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const =0
Get the delay between specific input and output.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
mlir::MLIRContext * getContext() const
mlir::MLIRContext * context
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, Cut &cut) const =0
Return a new operation that replaces the matched cut.
CutRewritePattern(mlir::MLIRContext *context)
virtual double getArea() const =0
Get the area cost of this pattern.
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
virtual bool match(const Cut &cut) const =0
Check if a cut matches this pattern.
virtual LocationAttr getLoc() const
Get location for this pattern (optional).
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.