CIRCT 20.0.0git
Loading...
Searching...
No Matches
MuxToControlFlow.cpp
Go to the documentation of this file.
1//===- MuxToControlFlow.cpp - Implement the MuxToControlFlow Pass ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implement a pass to convert muxes to control flow branches whenever it is
10// beneficial for performance (i.e., when expected work avoided is more than
11// branching costs)
12//
13//===----------------------------------------------------------------------===//
14
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/Support/Debug.h"
23
24#define DEBUG_TYPE "arc-mux-to-control-flow"
25
26namespace circt {
27namespace arc {
28#define GEN_PASS_DEF_MUXTOCONTROLFLOW
29#include "circt/Dialect/Arc/ArcPasses.h.inc"
30} // namespace arc
31} // namespace circt
32
33using namespace circt;
34using namespace arc;
35
36//===----------------------------------------------------------------------===//
37// MuxToControlFlow pass declarations
38//===----------------------------------------------------------------------===//
39
40namespace {
41
42/// Convert muxes to if-statements.
43struct MuxToControlFlowPass
44 : public arc::impl::MuxToControlFlowBase<MuxToControlFlowPass> {
45 MuxToControlFlowPass() = default;
46 MuxToControlFlowPass(const MuxToControlFlowPass &pass)
47 : MuxToControlFlowPass() {}
48
49 void runOnOperation() override;
50
51 Statistic numMuxesConverted{
52 this, "num-muxes-converted",
53 "Number of muxes that were converted to if-statements"};
54 Statistic numMuxesRetained{this, "num-muxes-retained",
55 "Number of muxes that were not converted"};
56};
57
58/// Abstract over muxes to easy addition of support for other operations.
59struct BranchInfo {
60 BranchInfo() = default;
61 BranchInfo(Value condition, Value trueValue, Value falseValue)
62 : condition(condition), trueValue(trueValue), falseValue(falseValue) {}
63
64 Value condition;
65 Value trueValue;
66 Value falseValue;
67
68 operator bool() { return condition && trueValue && falseValue; }
69};
70
71} // namespace
72
73//===----------------------------------------------------------------------===//
74// Helpers
75//===----------------------------------------------------------------------===//
76
77/// Check whether @param curr is valid to be moved into the if-branch, which is
78/// the stopping condition of the BFS traversal.
79static bool isValidToProceedTraversal(Operation *mux, Operation *curr,
80 Value useValue,
81 SmallPtrSetImpl<Operation *> &visited) {
82 for (auto res : curr->getResults()) {
83 for (auto *user : res.getUsers()) {
84 // The use-sites of all results have to be within the same branch, thus
85 // already have to be visited already. The only exception is the first
86 // operation in the branch used by the mux itself.
87 if (!visited.contains(user) && user != mux)
88 return false;
89
90 // The second part of the special case mentioned above (because otherwise
91 // we would also include the first operation of the other branch).
92 if (user == mux && res != useValue)
93 return false;
94
95 if (user->getBlock() != curr->getBlock())
96 return false;
97 }
98 }
99
100 return true;
101}
102
103/// Compute the set of operations that would only be used in the branch
104/// represented by @param useValue.
105static void computeFanIn(Operation *mux, Value useValue,
106 SmallPtrSetImpl<Operation *> &visited) {
107 auto *op = useValue.getDefiningOp();
108 if (!op)
109 return;
110
111 SmallVector<Operation *> worklist{op};
112
113 while (!worklist.empty()) {
114 auto *curr = worklist.front();
115 worklist.erase(worklist.begin());
116
117 if (visited.contains(curr))
118 continue;
119
120 if (!isValidToProceedTraversal(mux, curr, useValue, visited))
121 continue;
122
123 visited.insert(curr);
124
125 for (auto val : curr->getOperands()) {
126 if (auto *defOp = val.getDefiningOp())
127 worklist.push_back(defOp);
128 }
129 }
130}
131
132/// Clone ops that are used in both branches of an if-statement but not outside
133/// of it. This is just here because of experimentation reasons. Doing this
134/// might allow for better instruction scheduling to slightly reduce ISA
135/// register pressure (however, it is currently too naive to only take the
136/// beneficial situations), but it will increase binary size which is especially
137/// bad when the hot part would otherwise fit in instruction cache (but doesn't
138/// really matter when it doesn't fit anyways as there is no temporal locality
139/// anyways).
140[[maybe_unused]] static void
142 // Iterate over all operations at the same nesting level as the if-statement
143 // (not the operations inside the if-statement).
144 for (auto &op : llvm::reverse(*ifOp->getBlock())) {
145 if (op.getNumResults() == 0)
146 continue;
147
148 // Collect all users of the current operations results.
149 SmallVector<Operation *> users;
150 for (auto result : op.getResults())
151 users.append(llvm::to_vector(result.getUsers()));
152
153 auto parentsOfUsers =
154 llvm::map_range(users, [](auto user) { return user->getParentOp(); });
155
156 auto allUsersNestedInIf = llvm::any_of(parentsOfUsers, [&](auto *parent) {
157 return !(isa<mlir::scf::IfOp>(parent) &&
158 parent->getBlock() == op.getBlock());
159 });
160
161 // Check that all users of the results are nested inside the same scf.if
162 // operation
163 if (allUsersNestedInIf || !llvm::all_equal(parentsOfUsers))
164 continue;
165
166 DenseMap<Region *, Value> cloneMap;
167 for (auto &use : llvm::make_early_inc_range(op.getUses())) {
168 auto *parentRegion = use.getOwner()->getParentRegion();
169 if (!cloneMap.count(parentRegion)) {
170 OpBuilder builder(&parentRegion->front().front());
171 cloneMap[parentRegion] = builder.clone(op)->getResult(0);
172 }
173 use.set(cloneMap[parentRegion]);
174 }
175 }
176}
177
178/// Perform the actual conversion. Create the if-statement, move the operations
179/// in its regions and delete the mux.
180static void doConversion(Operation *op, BranchInfo info,
181 const SmallPtrSetImpl<Operation *> &thenOps,
182 const SmallPtrSetImpl<Operation *> &elseOps) {
183 if (op->getNumResults() != 1)
184 return;
185
186 // Build the scf.if operation with the scf.yields inside.
187 ImplicitLocOpBuilder builder(op->getLoc(), op);
188 mlir::scf::IfOp ifOp = builder.create<mlir::scf::IfOp>(
189 info.condition,
190 [&](OpBuilder &builder, Location loc) {
191 builder.create<mlir::scf::YieldOp>(loc, info.trueValue);
192 },
193 [&](OpBuilder &builder, Location loc) {
194 builder.create<mlir::scf::YieldOp>(loc, info.falseValue);
195 });
196
197 op->getResult(0).replaceAllUsesWith(ifOp.getResult(0));
198
199 for (auto &ops :
200 llvm::make_early_inc_range(op->getParentRegion()->getOps())) {
201 // Move operations into the then-branch if they are only used in there.
202 // The original lexicographical order is preserved.
203 if (thenOps.contains(&ops))
204 ops.moveBefore(ifOp.thenBlock()->getTerminator());
205
206 // Move operations into the else-branch if they are only used in there.
207 // The original lexicographical order is preserved.
208 if (elseOps.contains(&ops))
209 ops.moveBefore(ifOp.elseBlock()->getTerminator());
210 }
211
212 op->erase();
213
214 // NOTE: this is just here for some experimentation purposes
215 // cloneOpsIntoBranchesWhenUsedInBoth(ifOp);
216}
217
218/// Simple helper to invoke the runtime cost interface for every operation in a
219/// set and sum up the costs.
220static uint32_t getCostEstimate(const SmallPtrSetImpl<Operation *> &ops) {
221 uint32_t cost = 0;
222
223 for (auto *op : ops) {
224 if (auto *runtimeCostIF =
225 dyn_cast<RuntimeCostEstimateDialectInterface>(op->getDialect())) {
226 cost += runtimeCostIF->getCostEstimate(op);
227 } else {
228 LLVM_DEBUG(llvm::dbgs() << "No runtime cost estimate was provided for '"
229 << op->getName() << "', using default of 10\n");
230 cost += 10;
231 }
232 }
233
234 return cost;
235}
236
237//===----------------------------------------------------------------------===//
238// Decision functions (configure the pass here)
239//===----------------------------------------------------------------------===//
240
241/// Convert concrete operations that should be converted to if-statements to a
242/// more abstract representation the rest of the pass works with. This is the
243/// place where support for more operations can be added (nothing else has to be
244/// changed).
245static BranchInfo getConversionInfo(Operation *op) {
246 if (auto mux = dyn_cast<comb::MuxOp>(op))
247 return BranchInfo{mux.getCond(), mux.getTrueValue(), mux.getFalseValue()};
248
249 // TODO: we can also check for arith.select or other operations here
250
251 return {};
252}
253
254/// Use the cost measure of each branch to heuristically decide whether to
255/// actually perform the conversion.
256/// TODO: improve and fine-tune this
257static bool isBeneficialToConvert(Operation *op,
258 const SmallPtrSetImpl<Operation *> &thenOps,
259 const SmallPtrSetImpl<Operation *> &elseOps) {
260 const uint32_t thenCost = getCostEstimate(thenOps);
261 const uint32_t elseCost = getCostEstimate(elseOps);
262
263 // Due to the nature of mux sequences we need to make sure that a reasonable
264 // amount of operations stay in each if-branch because otherwise we end up
265 // with if-statements that only contain anther if-statement, which is usually
266 // more costly than keeping some muxes unconverted.
267 if (auto parent = op->getParentOfType<mlir::scf::IfOp>()) {
268 SmallPtrSet<Operation *, 32> ifBranchOps;
269
270 for (auto &nestedOp : *op->getBlock()) {
271 if (!thenOps.contains(&nestedOp) && !elseOps.contains(&nestedOp))
272 ifBranchOps.insert(&nestedOp);
273 }
274
275 if (getCostEstimate(ifBranchOps) < 100)
276 return false;
277 }
278
279 // return thenCost + elseCost >= 100 && (thenCost == 0 || elseCost == 0);
280 return (thenCost >= 100 || thenCost == 0) &&
281 (elseCost >= 100 || elseCost == 0) &&
282 std::abs((int)thenCost - (int)elseCost) >= 100;
283}
284
285//===----------------------------------------------------------------------===//
286// MuxToControlFlow pass definitions
287//===----------------------------------------------------------------------===//
288
289// FIXME: Assumes that the regions in which muxes exist are topologically
290// ordered.
291// FIXME: does not consider side-effects
292void MuxToControlFlowPass::runOnOperation() {
293 // Collect all operations that support the conversion to scf.if operations.
294 // Use 'walk' instead of 'getOps' as we also want to visit nested regions.
295 // We need to collect them because moving ops while iterating over them
296 // would require complicated iterator advancing/skipping but also tracking
297 // back to not miss supported operations.
298 SmallVector<Operation *> supportedOps;
299 getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
300 // Skip ops with graph regions and ops that can contain ops with write
301 // semantics for now until side-effects and topological ordering is properly
302 // handled.
303 if (isa<hw::HWModuleOp, arc::ModelOp>(op))
304 return WalkResult::skip();
305
306 if (getConversionInfo(op))
307 supportedOps.push_back(op);
308
309 return WalkResult::advance();
310 });
311
312 // We want to visit the operations bottom-up to visit the operations with the
313 // longest fan-in first. However, the other direction would also work with the
314 // current implementation.
315 for (auto *op : llvm::reverse(supportedOps)) {
316 auto info = getConversionInfo(op);
317
318 // Compute the operations in the fan-in of each branch and use them to
319 // decide whether the operation should be converted.
320 // Stop at the first value that's also used outside of the branch.
321 llvm::SmallPtrSet<Operation *, 32> thenOps, elseOps;
322 computeFanIn(op, info.trueValue, thenOps);
323 computeFanIn(op, info.falseValue, elseOps);
324
325 // Apply a cost measure to the operations in the branches and only convert
326 // when a performance increase can be expected.
327 if (isBeneficialToConvert(op, thenOps, elseOps)) {
328 doConversion(op, info, thenOps, elseOps);
329 ++numMuxesConverted;
330 } else {
331 ++numMuxesRetained;
332 }
333 }
334}
335
336std::unique_ptr<Pass> arc::createMuxToControlFlowPass() {
337 return std::make_unique<MuxToControlFlowPass>();
338}
static void doConversion(Operation *op, BranchInfo info, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Perform the actual conversion.
static uint32_t getCostEstimate(const SmallPtrSetImpl< Operation * > &ops)
Simple helper to invoke the runtime cost interface for every operation in a set and sum up the costs.
static bool isBeneficialToConvert(Operation *op, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Use the cost measure of each branch to heuristically decide whether to actually perform the conversio...
static void computeFanIn(Operation *mux, Value useValue, SmallPtrSetImpl< Operation * > &visited)
Compute the set of operations that would only be used in the branch represented by.
static bool isValidToProceedTraversal(Operation *mux, Operation *curr, Value useValue, SmallPtrSetImpl< Operation * > &visited)
Check whether.
static BranchInfo getConversionInfo(Operation *op)
Convert concrete operations that should be converted to if-statements to a more abstract representati...
static void cloneOpsIntoBranchesWhenUsedInBoth(mlir::scf::IfOp ifOp)
Clone ops that are used in both branches of an if-statement but not outside of it.
std::unique_ptr< mlir::Pass > createMuxToControlFlowPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.