CIRCT 23.0.0git
Loading...
Searching...
No Matches
BalanceMux.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the BalanceMux pass, which balances and optimizes mux
10// chains.
11//
12//===----------------------------------------------------------------------===//
13
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/LogicalResult.h"
22#include "llvm/Support/MathExtras.h"
23
24using namespace circt;
25using namespace comb;
26
27namespace circt {
28namespace comb {
29#define GEN_PASS_DEF_BALANCEMUX
30#include "circt/Dialect/Comb/Passes.h.inc"
31} // namespace comb
32} // namespace circt
33
34namespace {
35
36/// Mux chain with comparison folding pattern.
37class MuxChainWithComparison : public OpRewritePattern<MuxOp> {
38 unsigned muxChainThreshold;
39
40public:
41 // Set a higher benefit than PriorityEncoderReshape to run first.
42 MuxChainWithComparison(MLIRContext *context, unsigned muxChainThreshold)
43 : OpRewritePattern<MuxOp>(context, /*benefit=*/2),
44 muxChainThreshold(muxChainThreshold) {}
45 LogicalResult matchAndRewrite(MuxOp rootMux,
46 PatternRewriter &rewriter) const override {
47 auto fn = [muxChainThreshold = muxChainThreshold](size_t indexWidth,
48 size_t numEntries) {
49 // In this pattern, we consider it beneficial to fold mux chains
50 // with more than the threshold.
51 if (numEntries >= muxChainThreshold)
52 return MuxChainWithComparisonFoldingStyle::BalancedMuxTree;
53 return MuxChainWithComparisonFoldingStyle::None;
54 };
55 // Try folding on both false and true sides
56 return llvm::success(foldMuxChainWithComparison(rewriter, rootMux,
57 /*isFalseSide=*/true, fn) ||
58 foldMuxChainWithComparison(rewriter, rootMux,
59 /*isFalseSide=*/false, fn));
60 }
61};
62
63/// Rebalances a linear chain of muxes forming a priority encoder into a
64/// balanced tree structure. This reduces the depth of the mux tree from O(n)
65/// to O(log n).
66///
67/// For a priority encoder with n conditions, this transform:
68/// - Reduces depth from O(n) to O(log n) levels
69/// - Muxes: Creates exactly (n-1) muxes (same as original linear chain)
70/// - OR gates: Creates additional O(n log n) OR gates to combine
71class PriorityMuxReshape : public OpRewritePattern<MuxOp> {
72 unsigned muxChainThreshold;
73
74public:
75 PriorityMuxReshape(MLIRContext *context, unsigned muxChainThreshold)
76 : OpRewritePattern<MuxOp>(context, /*benefit=*/1),
77 muxChainThreshold(muxChainThreshold) {}
78
79 LogicalResult matchAndRewrite(MuxOp op,
80 PatternRewriter &rewriter) const override;
81
82private:
83 /// Helper function to collect a mux chain from a given side
84 std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
85 collectChain(MuxOp op, bool isFalseSide) const;
86
87 /// Build a balanced tree from the collected conditions and results
88 Value buildBalancedPriorityMux(PatternRewriter &rewriter,
89 ArrayRef<Value> conditions,
90 ArrayRef<Value> results, Value defaultValue,
91 ArrayRef<Location> locs) const;
92};
93} // namespace
94
95//===----------------------------------------------------------------------===//
96// Implementation
97//===----------------------------------------------------------------------===//
98
99LogicalResult
100PriorityMuxReshape::matchAndRewrite(MuxOp op, PatternRewriter &rewriter) const {
101 // Make sure that we're not looking at the intermediate node in a mux tree.
102 if (op->hasOneUse())
103 if (auto userMux = dyn_cast<MuxOp>(*op->user_begin()))
104 return failure();
105
106 // Early return if both or neither side are mux chains.
107 auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
108 auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
109 if ((trueMux && falseMux) || (!trueMux && !falseMux))
110 return failure();
111 bool useFalseSideChain = falseMux;
112
113 auto [conditions, results, locs] = collectChain(op, useFalseSideChain);
114 if (conditions.size() < muxChainThreshold)
115 return failure();
116
117 if (!useFalseSideChain) {
118 // For true-side chains, we need to invert all conditions
119 for (auto &cond : conditions) {
120 cond = rewriter.createOrFold<comb::XorOp>(
121 op.getLoc(), cond,
122 hw::ConstantOp::create(rewriter, op.getLoc(), APInt(1, 1)), true);
123 }
124 }
125
126 LDBG() << "Rebalanced priority mux with " << conditions.size()
127 << " conditions, using " << (useFalseSideChain ? "false" : "true")
128 << "-side chain.\n";
129
130 assert(conditions.size() + 1 == results.size() &&
131 "Expected one more result than conditions");
132 ArrayRef<Value> resultsRef(results);
133
134 // Build balanced tree and replace original op
135 Value balancedTree = buildBalancedPriorityMux(
136 rewriter, conditions, resultsRef.drop_back(), resultsRef.back(), locs);
137 replaceOpAndCopyNamehint(rewriter, op, balancedTree);
138 return success();
139}
140
141std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Location>>
142PriorityMuxReshape::collectChain(MuxOp op, bool isFalseSide) const {
143 SmallVector<Value> chainConditions, chainResults;
144 DenseSet<Value> seenConditions;
145 SmallVector<Location> chainLocs;
146
147 auto chainMux = isFalseSide ? op.getFalseValue().getDefiningOp<MuxOp>()
148 : op.getTrueValue().getDefiningOp<MuxOp>();
149
150 if (!chainMux)
151 return {chainConditions, chainResults, chainLocs};
152
153 // Helper lambdas to abstract the differences between false/true side chains
154 auto getChainResult = [&](MuxOp mux) -> Value {
155 return isFalseSide ? mux.getTrueValue() : mux.getFalseValue();
156 };
157
158 auto getChainNext = [&](MuxOp mux) -> Value {
159 return isFalseSide ? mux.getFalseValue() : mux.getTrueValue();
160 };
161
162 auto getRootResult = [&]() -> Value {
163 return isFalseSide ? op.getTrueValue() : op.getFalseValue();
164 };
165
166 // Start collecting the chain
167 seenConditions.insert(op.getCond());
168 chainConditions.push_back(op.getCond());
169 chainResults.push_back(getRootResult());
170 chainLocs.push_back(op.getLoc());
171
172 // Walk down the chain collecting all conditions and results
173 MuxOp currentMux = chainMux;
174 while (currentMux) {
175 // Only add unique conditions (outer muxes have priority)
176 if (seenConditions.insert(currentMux.getCond()).second) {
177 chainConditions.push_back(currentMux.getCond());
178 chainResults.push_back(getChainResult(currentMux));
179 chainLocs.push_back(currentMux.getLoc());
180 }
181
182 auto nextMux = getChainNext(currentMux).getDefiningOp<MuxOp>();
183 if (!nextMux || !nextMux->hasOneUse()) {
184 // Add the final default value
185 chainResults.push_back(getChainNext(currentMux));
186 break;
187 }
188 currentMux = nextMux;
189 }
190
191 return {chainConditions, chainResults, chainLocs};
192}
193
194// This function recursively constructs a balanced binary tree of muxes for a
195// priority encoder. It splits the conditions and results into halves,
196// combining the left half's conditions with an OR gate to select between
197// the left subtree (which includes the default for that half) and the right
198// subtree. This transforms a linear chain like:
199// a_0 ? r_0 : a_1 ? r_1 : ... : a_n ? r_n : default
200// into a balanced tree, reducing depth from O(n) to O(log n).
201// NOLINTNEXTLINE(misc-no-recursion)
202Value PriorityMuxReshape::buildBalancedPriorityMux(
203 PatternRewriter &rewriter, ArrayRef<Value> conditions,
204 ArrayRef<Value> results, Value defaultValue,
205 ArrayRef<Location> locs) const {
206 size_t size = conditions.size();
207 // Base cases.
208 if (size == 0)
209 return defaultValue;
210 if (size == 1)
211 return rewriter.createOrFold<MuxOp>(locs.front(), conditions.front(),
212 results.front(), defaultValue);
213
214 // Recursive case: split range in half. Take the ceiling to ensure the first
215 // half is larger.
216 // TODO: Ideally the separator index should be selected based on arrival times
217 // of results.
218 unsigned mid = llvm::divideCeil(size, 2);
219 assert(mid > 0);
220 auto loc = rewriter.getFusedLoc(ArrayRef<Location>(locs).take_front(mid));
221
222 // Build left and right subtrees. Use the last result as the default for the
223 // left subtree to ensure correct priority encoding.
224 Value leftTree = buildBalancedPriorityMux(
225 rewriter, conditions.take_front(mid), results.take_front(mid),
226 results.take_front(mid).back(), locs.take_front(mid));
227
228 Value rightTree = buildBalancedPriorityMux(
229 rewriter, conditions.drop_front(mid), results.drop_front(mid),
230 defaultValue, locs.drop_front(mid));
231
232 // Combine conditions from left half with OR
233 Value combinedCond =
234 rewriter.createOrFold<OrOp>(loc, conditions.take_front(mid), true);
235
236 // Create mux that selects between left and right subtrees
237 return MuxOp::create(rewriter, loc, combinedCond, leftTree, rightTree);
238}
239
240// or(mux(c_0, a_0, 0), mux(c_1, a_1, 0), ..., mux(c_n, a_n, 0)) ->
241// mux(c_0, a_0, mux(c_1, a_1, ...)) iff all conditions are independent.
242//
243// The mux tree should then be balanced by other MuxOp patterns.
244struct OrOfMuxToMuxChain : public OpRewritePattern<OrOp> {
246
247 LogicalResult matchAndRewrite(OrOp op,
248 PatternRewriter &rewriter) const override {
249 if (!op.getTwoState())
250 return failure();
251
252 SmallVector<Value> conditions;
253 for (Value operand : op.getOperands()) {
254 MuxOp mux = operand.getDefiningOp<MuxOp>();
255 if (!mux || !mux.getTwoState() ||
256 !matchPattern(mux.getFalseValue(), mlir::m_Zero()))
257 return failure();
258
259 conditions.push_back(mux.getCond());
260 }
261
262 if (!areConditionsIndependent(conditions))
263 return failure();
264
265 // Construct the mux chain from back to front.
266 SmallVector<Value> values(op.getOperands().drop_back());
267 Value v = op.getOperands().back();
268 while (!values.empty()) {
269 auto mux = values.pop_back_val().getDefiningOp<comb::MuxOp>();
270 v = comb::MuxOp::create(rewriter, op.getLoc(), mux.getCond(),
271 mux.getTrueValue(), v, /*twoState=*/true);
272 }
273 replaceOpAndCopyNamehint(rewriter, op, v);
274 return success();
275 }
276
277 // Returns true if the given i1 conditions are independent. That is, at most
278 // one condition can be true at a time.
279 //
280 // Currently we take a shortcut and check if the conditions are defined by
281 // ICmpEqs for different values. This is a common pattern in priority
282 // encoders.
283 bool areConditionsIndependent(ArrayRef<Value> conditions) const {
284 DenseSet<IntegerAttr> seenConstants;
285 for (Value v : conditions) {
286 auto icmp = v.getDefiningOp<ICmpOp>();
287 IntegerAttr value;
288 if (!icmp || icmp.getPredicate() != ICmpPredicate::eq ||
289 !matchPattern(icmp.getRhs(), mlir::m_Constant(&value)))
290 return false;
291
292 if (!seenConstants.insert(value).second)
293 return false;
294 }
295 return true;
296 }
297};
298
299/// Pass that performs enhanced mux chain optimizations
300struct BalanceMuxPass : public impl::BalanceMuxBase<BalanceMuxPass> {
301 using BalanceMuxBase::BalanceMuxBase;
302
303 void runOnOperation() override {
304 Operation *op = getOperation();
305 MLIRContext *context = op->getContext();
306
307 RewritePatternSet patterns(context);
308 patterns.add<MuxChainWithComparison, PriorityMuxReshape>(context,
309 muxChainThreshold);
311
312 if (failed(applyPatternsGreedily(op, std::move(patterns))))
313 return signalPassFailure();
314 }
315};
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
create(data_type, value)
Definition hw.py:433
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1
Pass that performs enhanced mux chain optimizations.
void runOnOperation() override
LogicalResult matchAndRewrite(OrOp op, PatternRewriter &rewriter) const override
bool areConditionsIndependent(ArrayRef< Value > conditions) const