CIRCT 22.0.0git
Loading...
Searching...
No Matches
BalanceMux.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the BalanceMux pass, which balances and optimizes mux
10// chains.
11//
12//===----------------------------------------------------------------------===//
13
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/LogicalResult.h"
22#include "llvm/Support/MathExtras.h"
23
24using namespace circt;
25using namespace comb;
26
27namespace circt {
28namespace comb {
29#define GEN_PASS_DEF_BALANCEMUX
30#include "circt/Dialect/Comb/Passes.h.inc"
31} // namespace comb
32} // namespace circt
33
34namespace {
35
36/// Mux chain with comparison folding pattern.
37class MuxChainWithComparison : public OpRewritePattern<MuxOp> {
38 unsigned muxChainThreshold;
39
40public:
41 // Set a higher benefit than PriorityEncoderReshape to run first.
42 MuxChainWithComparison(MLIRContext *context, unsigned muxChainThreshold)
43 : OpRewritePattern<MuxOp>(context, /*benefit=*/2),
44 muxChainThreshold(muxChainThreshold) {}
45 LogicalResult matchAndRewrite(MuxOp rootMux,
46 PatternRewriter &rewriter) const override {
47 auto fn = [muxChainThreshold = muxChainThreshold](size_t indexWidth,
48 size_t numEntries) {
49 // In this pattern, we consider it beneficial to fold mux chains
50 // with more than the threshold.
51 if (numEntries >= muxChainThreshold)
52 return MuxChainWithComparisonFoldingStyle::BalancedMuxTree;
53 return MuxChainWithComparisonFoldingStyle::None;
54 };
55 // Try folding on both false and true sides
56 return llvm::success(foldMuxChainWithComparison(rewriter, rootMux,
57 /*isFalseSide=*/true, fn) ||
58 foldMuxChainWithComparison(rewriter, rootMux,
59 /*isFalseSide=*/false, fn));
60 }
61};
62
63/// Rebalances a linear chain of muxes forming a priority encoder into a
64/// balanced tree structure. This reduces the depth of the mux tree from O(n)
65/// to O(log n).
66///
67/// For a priority encoder with n conditions, this transform:
68/// - Reduces depth from O(n) to O(log n) levels
69/// - Muxes: Creates exactly (n-1) muxes (same as original linear chain)
70/// - OR gates: Creates additional O(n log n) OR gates to combine
71class PriorityMuxReshape : public OpRewritePattern<MuxOp> {
72 unsigned muxChainThreshold;
73
74public:
75 PriorityMuxReshape(MLIRContext *context, unsigned muxChainThreshold)
76 : OpRewritePattern<MuxOp>(context, /*benefit=*/1),
77 muxChainThreshold(muxChainThreshold) {}
78
79 LogicalResult matchAndRewrite(MuxOp op,
80 PatternRewriter &rewriter) const override;
81
82private:
83 /// Helper function to collect a mux chain from a given side
84 std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
85 collectChain(MuxOp op, bool isFalseSide) const;
86
87 /// Build a balanced tree from the collected conditions and results
88 Value buildBalancedPriorityMux(PatternRewriter &rewriter,
89 ArrayRef<Value> conditions,
90 ArrayRef<Value> results, Value defaultValue,
91 ArrayRef<Location> locs) const;
92};
93}; // namespace
94
95//===----------------------------------------------------------------------===//
96// Implementation
97//===----------------------------------------------------------------------===//
98
99LogicalResult
100PriorityMuxReshape::matchAndRewrite(MuxOp op, PatternRewriter &rewriter) const {
101 // Make sure that we're not looking at the intermediate node in a mux tree.
102 if (op->hasOneUse())
103 if (auto userMux = dyn_cast<MuxOp>(*op->user_begin()))
104 return failure();
105
106 // Early return if both or neither side are mux chains.
107 auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
108 auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
109 if ((trueMux && falseMux) || (!trueMux && !falseMux))
110 return failure();
111 bool useFalseSideChain = falseMux;
112
113 auto [conditions, results, locs] = collectChain(op, useFalseSideChain);
114 if (conditions.size() < muxChainThreshold)
115 return failure();
116
117 if (!useFalseSideChain) {
118 // For true-side chains, we need to invert all conditions
119 for (auto &cond : conditions) {
120 cond = rewriter.createOrFold<comb::XorOp>(
121 op.getLoc(), cond,
122 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(1, 1)), true);
123 }
124 }
125
126 LDBG() << "Rebalanced priority mux with " << conditions.size()
127 << " conditions, using " << (useFalseSideChain ? "false" : "true")
128 << "-side chain.\n";
129
130 assert(conditions.size() + 1 == results.size() &&
131 "Expected one more result than conditions");
132 ArrayRef<Value> resultsRef(results);
133
134 // Build balanced tree and replace original op
135 Value balancedTree = buildBalancedPriorityMux(
136 rewriter, conditions, resultsRef.drop_back(), resultsRef.back(), locs);
137 replaceOpAndCopyNamehint(rewriter, op, balancedTree);
138 return success();
139}
140
141std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
142PriorityMuxReshape::collectChain(MuxOp op, bool isFalseSide) const {
143 SmallVector<Value> chainConditions, chainResults;
144 DenseSet<Value> seenConditions;
145 SmallVector<Location> chainLocs;
146
147 auto chainMux = isFalseSide ? op.getFalseValue().getDefiningOp<MuxOp>()
148 : op.getTrueValue().getDefiningOp<MuxOp>();
149
150 if (!chainMux)
151 return {chainConditions, chainResults, chainLocs};
152
153 // Helper lambdas to abstract the differences between false/true side chains
154 auto getChainResult = [&](MuxOp mux) -> Value {
155 return isFalseSide ? mux.getTrueValue() : mux.getFalseValue();
156 };
157
158 auto getChainNext = [&](MuxOp mux) -> Value {
159 return isFalseSide ? mux.getFalseValue() : mux.getTrueValue();
160 };
161
162 auto getRootResult = [&]() -> Value {
163 return isFalseSide ? op.getTrueValue() : op.getFalseValue();
164 };
165
166 // Start collecting the chain
167 seenConditions.insert(op.getCond());
168 chainConditions.push_back(op.getCond());
169 chainResults.push_back(getRootResult());
170 chainLocs.push_back(op.getLoc());
171
172 // Walk down the chain collecting all conditions and results
173 MuxOp currentMux = chainMux;
174 while (currentMux) {
175 // Only add unique conditions (outer muxes have priority)
176 if (seenConditions.insert(currentMux.getCond()).second) {
177 chainConditions.push_back(currentMux.getCond());
178 chainResults.push_back(getChainResult(currentMux));
179 chainLocs.push_back(currentMux.getLoc());
180 }
181
182 auto nextMux = getChainNext(currentMux).getDefiningOp<MuxOp>();
183 if (!nextMux || !nextMux->hasOneUse()) {
184 // Add the final default value
185 chainResults.push_back(getChainNext(currentMux));
186 break;
187 }
188 currentMux = nextMux;
189 }
190
191 return {chainConditions, chainResults, chainLocs};
192}
193
194// This function recursively constructs a balanced binary tree of muxes for a
195// priority encoder. It splits the conditions and results into halves,
196// combining the left half's conditions with an OR gate to select between
197// the left subtree (which includes the default for that half) and the right
198// subtree. This transforms a linear chain like:
199// a_0 ? r_0 : a_1 ? r_1 : ... : a_n ? r_n : default
200// into a balanced tree, reducing depth from O(n) to O(log n).
201// NOLINTNEXTLINE(misc-no-recursion)
202Value PriorityMuxReshape::buildBalancedPriorityMux(
203 PatternRewriter &rewriter, ArrayRef<Value> conditions,
204 ArrayRef<Value> results, Value defaultValue,
205 ArrayRef<Location> locs) const {
206 size_t size = conditions.size();
207 // Base cases.
208 if (size == 0)
209 return defaultValue;
210 if (size == 1)
211 return rewriter.createOrFold<MuxOp>(locs.front(), conditions.front(),
212 results.front(), defaultValue);
213
214 // Recursive case: split range in half. Take the ceiling to ensure the first
215 // half is larger.
216 // TODO: Ideally the separator index should be selected based on arrival times
217 // of results.
218 unsigned mid = llvm::divideCeil(size, 2);
219 assert(mid > 0);
220 auto loc = rewriter.getFusedLoc(ArrayRef<Location>(locs).take_front(mid));
221
222 // Build left and right subtrees. Use the last result as the default for the
223 // left subtree to ensure correct priority encoding.
224 Value leftTree = buildBalancedPriorityMux(
225 rewriter, conditions.take_front(mid), results.take_front(mid),
226 results.take_front(mid).back(), locs.take_front(mid));
227
228 Value rightTree = buildBalancedPriorityMux(
229 rewriter, conditions.drop_front(mid), results.drop_front(mid),
230 defaultValue, locs.drop_front(mid));
231
232 // Combine conditions from left half with OR
233 Value combinedCond =
234 rewriter.createOrFold<OrOp>(loc, conditions.take_front(mid), true);
235
236 // Create mux that selects between left and right subtrees
237 return rewriter.create<MuxOp>(loc, combinedCond, leftTree, rightTree);
238}
239
240/// Pass that performs enhanced mux chain optimizations
241struct BalanceMuxPass : public impl::BalanceMuxBase<BalanceMuxPass> {
242 using BalanceMuxBase::BalanceMuxBase;
243
244 void runOnOperation() override {
245 Operation *op = getOperation();
246 MLIRContext *context = op->getContext();
247
248 RewritePatternSet patterns(context);
249 patterns.add<MuxChainWithComparison, PriorityMuxReshape>(context,
250 muxChainThreshold);
251
252 if (failed(applyPatternsGreedily(op, std::move(patterns))))
253 return signalPassFailure();
254 }
255};
assert(baseType &&"element must be base type")
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1
Pass that performs enhanced mux chain optimizations.
void runOnOperation() override