18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/LogicalResult.h"
22#include "llvm/Support/MathExtras.h"
29#define GEN_PASS_DEF_BALANCEMUX
30#include "circt/Dialect/Comb/Passes.h.inc"
// NOTE(review): this listing is a partial extract; the enclosing struct
// header and several interior lines of this pattern are not visible here.

// Entry-count threshold; compared against `numEntries` in the style callback
// built inside matchAndRewrite below.
38 unsigned muxChainThreshold;
// Constructor: records the caller-supplied threshold in the member above.
// The base-class initializer line is not visible in this extract.
42 MuxChainWithComparison(MLIRContext *context,
unsigned muxChainThreshold)
44 muxChainThreshold(muxChainThreshold) {}
// Rewrite hook for a root MuxOp. Only the construction of the style callback
// `fn` is visible; how `fn` is consumed is not shown (presumably it is handed
// to a mux-chain folding helper -- confirm against the full file).
45 LogicalResult matchAndRewrite(
MuxOp rootMux,
46 PatternRewriter &rewriter)
const override {
// The lambda captures the threshold by copy under the same name, so the
// callback does not refer to `this`.
47 auto fn = [muxChainThreshold = muxChainThreshold](
size_t indexWidth,
// Style decision: request a balanced mux tree once the number of entries
// reaches the threshold, otherwise request no folding.
51 if (numEntries >= muxChainThreshold)
52 return MuxChainWithComparisonFoldingStyle::BalancedMuxTree;
53 return MuxChainWithComparisonFoldingStyle::None;
// NOTE(review): partial extract of the PriorityMuxReshape pattern declaration;
// the struct header and closing brace are not visible here.

// Minimum number of collected conditions; matchAndRewrite compares
// `conditions.size()` against this value.
72 unsigned muxChainThreshold;
// Constructor: records the caller-supplied threshold in the member above.
// The base-class initializer line is not visible in this extract.
75 PriorityMuxReshape(MLIRContext *context,
unsigned muxChainThreshold)
77 muxChainThreshold(muxChainThreshold) {}
// Rewrite hook for a MuxOp; defined out of line.
79 LogicalResult matchAndRewrite(
MuxOp op,
80 PatternRewriter &rewriter)
const override;
// Walks the mux chain hanging off one operand of `op` (the false operand when
// `isFalseSide` is true, otherwise the true operand) and returns three
// vectors: conditions, results, and locations.
84 std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
85 collectChain(
MuxOp op,
bool isFalseSide)
const;
// Recursively builds a mux tree from `conditions` / `results`, with
// `defaultValue` passed through to the right-hand subtree; `locs` supplies
// the locations used for the created ops.
88 Value buildBalancedPriorityMux(PatternRewriter &rewriter,
89 ArrayRef<Value> conditions,
90 ArrayRef<Value> results, Value defaultValue,
91 ArrayRef<Location> locs)
const;
// Out-of-line definition of the PriorityMuxReshape rewrite hook.
// NOTE(review): the return-type line, several guarded statements, the loop
// body and the final replacement/return are missing from this extract; the
// comments below describe only what the visible lines establish.
100PriorityMuxReshape::matchAndRewrite(
MuxOp op, PatternRewriter &rewriter)
const {
// Checks whether the first user of `op` is itself a MuxOp. The statement this
// condition guards (and any preceding use-count check) is not visible.
103 if (
auto userMux = dyn_cast<MuxOp>(*op->user_begin()))
// Find the mux ops, if any, that define the true and false operands.
107 auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
108 auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
// True when both operands or neither operand is produced by a mux, i.e. when
// there is not exactly one chain side. The guarded statement is not visible
// (presumably the pattern bails out here -- confirm).
109 if ((trueMux && falseMux) || (!trueMux && !falseMux))
// Follow the false operand when it is the mux-defined one, else the true one.
111 bool useFalseSideChain = falseMux;
113 auto [conditions, results, locs] = collectChain(op, useFalseSideChain);
// Chains with fewer conditions than the threshold take the (not visible)
// guarded path.
114 if (conditions.size() < muxChainThreshold)
// Extra handling for true-side chains: iterates the conditions by mutable
// reference. The loop body is not visible (presumably it rewrites each
// condition so the chain can be treated like a false-side one -- confirm).
117 if (!useFalseSideChain) {
119 for (
auto &cond : conditions) {
// Debug log of the chain size and which side was followed.
126 LDBG() <<
"Rebalanced priority mux with " << conditions.size()
127 <<
" conditions, using " << (useFalseSideChain ?
"false" :
"true")
// `results` carries one trailing entry beyond the per-condition results; it
// is peeled off below and used as the default value.
130 assert(conditions.size() + 1 == results.size() &&
131 "Expected one more result than conditions");
132 ArrayRef<Value> resultsRef(results);
// Build the tree from all but the last result, with the last as the default.
135 Value balancedTree = buildBalancedPriorityMux(
136 rewriter, conditions, resultsRef.drop_back(), resultsRef.back(), locs);
// Collects the conditions, results and locations of a mux chain rooted at
// `op`, following the false operand when `isFalseSide` is true and the true
// operand otherwise.
// NOTE(review): guard conditions, the loop header and closing braces are
// missing from this extract.
141std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
142PriorityMuxReshape::collectChain(
MuxOp op,
bool isFalseSide)
const {
143 SmallVector<Value> chainConditions, chainResults;
// Conditions already recorded; used below to skip repeats.
144 DenseSet<Value> seenConditions;
145 SmallVector<Location> chainLocs;
// The mux (if any) defining the operand on the side being followed.
147 auto chainMux = isFalseSide ? op.getFalseValue().getDefiningOp<
MuxOp>()
148 : op.getTrueValue().getDefiningOp<
MuxOp>();
// Early return; no visible line has pushed into the vectors before this
// point. The guard is not visible (presumably taken when `chainMux` is
// null -- confirm).
151 return {chainConditions, chainResults, chainLocs};
// Operand of `mux` on the side opposite to the chain direction.
154 auto getChainResult = [&](
MuxOp mux) -> Value {
155 return isFalseSide ? mux.getTrueValue() : mux.getFalseValue();
// Operand of `mux` on the chain side, i.e. where the chain continues.
158 auto getChainNext = [&](
MuxOp mux) -> Value {
159 return isFalseSide ? mux.getFalseValue() : mux.getTrueValue();
// Operand of the root `op` on the side opposite to the chain direction.
162 auto getRootResult = [&]() -> Value {
163 return isFalseSide ? op.getTrueValue() : op.getFalseValue();
// Seed the vectors with the root mux's condition, result and location.
167 seenConditions.insert(op.getCond());
168 chainConditions.push_back(op.getCond());
169 chainResults.push_back(getRootResult());
170 chainLocs.push_back(op.getLoc());
173 MuxOp currentMux = chainMux;
// Record this mux only if its condition was newly inserted into the set.
176 if (seenConditions.insert(currentMux.getCond()).second) {
177 chainConditions.push_back(currentMux.getCond());
178 chainResults.push_back(getChainResult(currentMux));
179 chainLocs.push_back(currentMux.getLoc());
// Look at the producer of the chain-side operand. When it is not a mux, or
// is a mux without exactly one use, that operand is appended as a result;
// the lines that follow inside this branch are not visible (presumably the
// walk stops here -- confirm).
182 auto nextMux = getChainNext(currentMux).getDefiningOp<
MuxOp>();
183 if (!nextMux || !nextMux->hasOneUse()) {
185 chainResults.push_back(getChainNext(currentMux));
// Otherwise advance to the next mux in the chain.
188 currentMux = nextMux;
191 return {chainConditions, chainResults, chainLocs};
// Recursively builds a mux tree over `conditions` / `results`, splitting the
// inputs into a front half and a back half and selecting between the two
// subtrees with an OR of the front-half conditions.
// NOTE(review): the base-case guard and the declaration receiving the OR
// result are missing from this extract.
202Value PriorityMuxReshape::buildBalancedPriorityMux(
203 PatternRewriter &rewriter, ArrayRef<Value> conditions,
204 ArrayRef<Value> results, Value defaultValue,
205 ArrayRef<Location> locs)
const {
206 size_t size = conditions.size();
// Leaf: a single mux choosing between the first result and the default.
// The guard is not visible (presumably `size == 1` -- confirm).
211 return rewriter.createOrFold<
MuxOp>(locs.front(), conditions.front(),
212 results.front(), defaultValue);
// Split point: the front half gets ceil(size / 2) entries.
218 unsigned mid = llvm::divideCeil(size, 2);
// Fused location covering the front-half entries.
220 auto loc = rewriter.getFusedLoc(ArrayRef<Location>(locs).take_front(mid));
// Front subtree: first `mid` conditions/results; its default is the last
// result of that same front half.
224 Value leftTree = buildBalancedPriorityMux(
225 rewriter, conditions.take_front(mid), results.take_front(mid),
226 results.take_front(mid).back(), locs.take_front(mid));
// Back subtree: remaining entries, inheriting the caller's default.
228 Value rightTree = buildBalancedPriorityMux(
229 rewriter, conditions.drop_front(mid), results.drop_front(mid),
230 defaultValue, locs.drop_front(mid));
// OR of all front-half conditions. The variable receiving this value is not
// visible (presumably `combinedCond`, used below); the meaning of the
// trailing `true` argument is not shown here -- confirm against OrOp's
// builder.
234 rewriter.createOrFold<
OrOp>(loc, conditions.take_front(mid),
true);
// Select the front subtree when `combinedCond` holds, else the back subtree.
237 return rewriter.create<
MuxOp>(loc, combinedCond, leftTree, rightTree);
// NOTE(review): partial extract of the pass definition; the struct header,
// the runOnOperation header and the second argument of `patterns.add` are
// not visible here.

// Inherit the constructors of the generated pass base class.
242 using BalanceMuxBase::BalanceMuxBase;
// Operation the pass runs on, and its MLIR context.
245 Operation *op = getOperation();
246 MLIRContext *context = op->getContext();
// Register both mux patterns. Each takes `context` plus one more argument
// that is not visible (presumably the mux-chain threshold -- confirm).
248 RewritePatternSet
patterns(context);
249 patterns.add<MuxChainWithComparison, PriorityMuxReshape>(context,
// Run the greedy rewrite driver; report pass failure if it fails.
252 if (failed(applyPatternsGreedily(op, std::move(
patterns))))
253 return signalPassFailure();
assert(baseType &&"element must be base type")
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Pass that performs enhanced mux chain optimizations.
void runOnOperation() override