18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/LogicalResult.h"
22#include "llvm/Support/MathExtras.h"
29#define GEN_PASS_DEF_BALANCEMUX
30#include "circt/Dialect/Comb/Passes.h.inc"
38 unsigned muxChainThreshold;
42 MuxChainWithComparison(MLIRContext *
context,
unsigned muxChainThreshold)
44 muxChainThreshold(muxChainThreshold) {}
45 LogicalResult matchAndRewrite(
MuxOp rootMux,
46 PatternRewriter &rewriter)
const override {
47 auto fn = [muxChainThreshold = muxChainThreshold](
size_t indexWidth,
51 if (numEntries >= muxChainThreshold)
52 return MuxChainWithComparisonFoldingStyle::BalancedMuxTree;
53 return MuxChainWithComparisonFoldingStyle::None;
72 unsigned muxChainThreshold;
75 PriorityMuxReshape(MLIRContext *
context,
unsigned muxChainThreshold)
77 muxChainThreshold(muxChainThreshold) {}
79 LogicalResult matchAndRewrite(
MuxOp op,
80 PatternRewriter &rewriter)
const override;
84 std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
85 collectChain(
MuxOp op,
bool isFalseSide)
const;
88 Value buildBalancedPriorityMux(PatternRewriter &rewriter,
89 ArrayRef<Value> conditions,
90 ArrayRef<Value> results, Value defaultValue,
91 ArrayRef<Location> locs)
const;
100PriorityMuxReshape::matchAndRewrite(
MuxOp op, PatternRewriter &rewriter)
const {
103 if (
auto userMux = dyn_cast<MuxOp>(*op->user_begin()))
107 auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
108 auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
109 if ((trueMux && falseMux) || (!trueMux && !falseMux))
111 bool useFalseSideChain = falseMux;
113 auto [conditions, results, locs] = collectChain(op, useFalseSideChain);
114 if (conditions.size() < muxChainThreshold)
117 if (!useFalseSideChain) {
119 for (
auto &cond : conditions) {
126 LDBG() <<
"Rebalanced priority mux with " << conditions.size()
127 <<
" conditions, using " << (useFalseSideChain ?
"false" :
"true")
130 assert(conditions.size() + 1 == results.size() &&
131 "Expected one more result than conditions");
132 ArrayRef<Value> resultsRef(results);
135 Value balancedTree = buildBalancedPriorityMux(
136 rewriter, conditions, resultsRef.drop_back(), resultsRef.back(), locs);
141std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
142PriorityMuxReshape::collectChain(
MuxOp op,
bool isFalseSide)
const {
143 SmallVector<Value> chainConditions, chainResults;
144 DenseSet<Value> seenConditions;
145 SmallVector<Location> chainLocs;
147 auto chainMux = isFalseSide ? op.getFalseValue().getDefiningOp<
MuxOp>()
148 : op.getTrueValue().getDefiningOp<
MuxOp>();
151 return {chainConditions, chainResults, chainLocs};
154 auto getChainResult = [&](
MuxOp mux) -> Value {
155 return isFalseSide ? mux.getTrueValue() : mux.getFalseValue();
158 auto getChainNext = [&](
MuxOp mux) -> Value {
159 return isFalseSide ? mux.getFalseValue() : mux.getTrueValue();
162 auto getRootResult = [&]() -> Value {
163 return isFalseSide ? op.getTrueValue() : op.getFalseValue();
167 seenConditions.insert(op.getCond());
168 chainConditions.push_back(op.getCond());
169 chainResults.push_back(getRootResult());
170 chainLocs.push_back(op.getLoc());
173 MuxOp currentMux = chainMux;
176 if (seenConditions.insert(currentMux.getCond()).second) {
177 chainConditions.push_back(currentMux.getCond());
178 chainResults.push_back(getChainResult(currentMux));
179 chainLocs.push_back(currentMux.getLoc());
182 auto nextMux = getChainNext(currentMux).getDefiningOp<
MuxOp>();
183 if (!nextMux || !nextMux->hasOneUse()) {
185 chainResults.push_back(getChainNext(currentMux));
188 currentMux = nextMux;
191 return {chainConditions, chainResults, chainLocs};
202Value PriorityMuxReshape::buildBalancedPriorityMux(
203 PatternRewriter &rewriter, ArrayRef<Value> conditions,
204 ArrayRef<Value> results, Value defaultValue,
205 ArrayRef<Location> locs)
const {
206 size_t size = conditions.size();
211 return rewriter.createOrFold<
MuxOp>(locs.front(), conditions.front(),
212 results.front(), defaultValue);
218 unsigned mid = llvm::divideCeil(size, 2);
220 auto loc = rewriter.getFusedLoc(ArrayRef<Location>(locs).take_front(mid));
224 Value leftTree = buildBalancedPriorityMux(
225 rewriter, conditions.take_front(mid), results.take_front(mid),
226 results.take_front(mid).back(), locs.take_front(mid));
228 Value rightTree = buildBalancedPriorityMux(
229 rewriter, conditions.drop_front(mid), results.drop_front(mid),
230 defaultValue, locs.drop_front(mid));
234 rewriter.createOrFold<
OrOp>(loc, conditions.take_front(mid),
true);
237 return MuxOp::create(rewriter, loc, combinedCond, leftTree, rightTree);
248 PatternRewriter &rewriter)
const override {
249 if (!op.getTwoState())
252 SmallVector<Value> conditions;
253 for (Value operand : op.getOperands()) {
255 if (!mux || !mux.getTwoState() ||
256 !matchPattern(mux.getFalseValue(), mlir::m_Zero()))
259 conditions.push_back(mux.getCond());
266 SmallVector<Value> values(op.getOperands().drop_back());
267 Value v = op.getOperands().back();
268 while (!values.empty()) {
269 auto mux = values.pop_back_val().getDefiningOp<
comb::MuxOp>();
270 v = comb::MuxOp::create(rewriter, op.getLoc(), mux.getCond(),
271 mux.getTrueValue(), v,
true);
284 DenseSet<IntegerAttr> seenConstants;
285 for (Value v : conditions) {
286 auto icmp = v.getDefiningOp<ICmpOp>();
288 if (!icmp || icmp.getPredicate() != ICmpPredicate::eq ||
289 !matchPattern(icmp.getRhs(), mlir::m_Constant(&value)))
292 if (!seenConstants.insert(value).second)
301 using BalanceMuxBase::BalanceMuxBase;
304 Operation *op = getOperation();
305 MLIRContext *
context = op->getContext();
312 if (failed(applyPatternsGreedily(op, std::move(
patterns))))
313 return signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Pass that performs enhanced mux chain optimizations.
void runOnOperation() override
LogicalResult matchAndRewrite(OrOp op, PatternRewriter &rewriter) const override
bool areConditionsIndependent(ArrayRef< Value > conditions) const