CIRCT  20.0.0git
MuxToControlFlow.cpp
Go to the documentation of this file.
1 //===- MuxToControlFlow.cpp - Implement the MuxToControlFlow Pass ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implement a pass to convert muxes to control flow branches whenever it is
10 // beneficial for performance (i.e., when expected work avoided is more than
11 // branching costs)
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "circt/Dialect/HW/HWOps.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Pass/Pass.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "arc-mux-to-control-flow"
25 
26 namespace circt {
27 namespace arc {
28 #define GEN_PASS_DEF_MUXTOCONTROLFLOW
29 #include "circt/Dialect/Arc/ArcPasses.h.inc"
30 } // namespace arc
31 } // namespace circt
32 
33 using namespace circt;
34 using namespace arc;
35 
36 //===----------------------------------------------------------------------===//
37 // MuxToControlFlow pass declarations
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
42 /// Convert muxes to if-statements.
43 struct MuxToControlFlowPass
44  : public arc::impl::MuxToControlFlowBase<MuxToControlFlowPass> {
45  MuxToControlFlowPass() = default;
46  MuxToControlFlowPass(const MuxToControlFlowPass &pass)
47  : MuxToControlFlowPass() {}
48 
49  void runOnOperation() override;
50 
51  Statistic numMuxesConverted{
52  this, "num-muxes-converted",
53  "Number of muxes that were converted to if-statements"};
54  Statistic numMuxesRetained{this, "num-muxes-retained",
55  "Number of muxes that were not converted"};
56 };
57 
58 /// Abstract over muxes to easy addition of support for other operations.
59 struct BranchInfo {
60  BranchInfo() = default;
61  BranchInfo(Value condition, Value trueValue, Value falseValue)
62  : condition(condition), trueValue(trueValue), falseValue(falseValue) {}
63 
64  Value condition;
65  Value trueValue;
66  Value falseValue;
67 
68  operator bool() { return condition && trueValue && falseValue; }
69 };
70 
71 } // namespace
72 
73 //===----------------------------------------------------------------------===//
74 // Helpers
75 //===----------------------------------------------------------------------===//
76 
77 /// Check whether @param curr is valid to be moved into the if-branch, which is
78 /// the stopping condition of the BFS traversal.
79 static bool isValidToProceedTraversal(Operation *mux, Operation *curr,
80  Value useValue,
81  SmallPtrSetImpl<Operation *> &visited) {
82  for (auto res : curr->getResults()) {
83  for (auto *user : res.getUsers()) {
84  // The use-sites of all results have to be within the same branch, thus
85  // already have to be visited already. The only exception is the first
86  // operation in the branch used by the mux itself.
87  if (!visited.contains(user) && user != mux)
88  return false;
89 
90  // The second part of the special case mentioned above (because otherwise
91  // we would also include the first operation of the other branch).
92  if (user == mux && res != useValue)
93  return false;
94 
95  if (user->getBlock() != curr->getBlock())
96  return false;
97  }
98  }
99 
100  return true;
101 }
102 
103 /// Compute the set of operations that would only be used in the branch
104 /// represented by @param useValue.
105 static void computeFanIn(Operation *mux, Value useValue,
106  SmallPtrSetImpl<Operation *> &visited) {
107  auto *op = useValue.getDefiningOp();
108  if (!op)
109  return;
110 
111  SmallVector<Operation *> worklist{op};
112 
113  while (!worklist.empty()) {
114  auto *curr = worklist.front();
115  worklist.erase(worklist.begin());
116 
117  if (visited.contains(curr))
118  continue;
119 
120  if (!isValidToProceedTraversal(mux, curr, useValue, visited))
121  continue;
122 
123  visited.insert(curr);
124 
125  for (auto val : curr->getOperands()) {
126  if (auto *defOp = val.getDefiningOp())
127  worklist.push_back(defOp);
128  }
129  }
130 }
131 
132 /// Clone ops that are used in both branches of an if-statement but not outside
133 /// of it. This is just here because of experimentation reasons. Doing this
134 /// might allow for better instruction scheduling to slightly reduce ISA
135 /// register pressure (however, it is currently too naive to only take the
136 /// beneficial situations), but it will increase binary size which is especially
137 /// bad when the hot part would otherwise fit in instruction cache (but doesn't
138 /// really matter when it doesn't fit anyways as there is no temporal locality
139 /// anyways).
140 [[maybe_unused]] static void
141 cloneOpsIntoBranchesWhenUsedInBoth(mlir::scf::IfOp ifOp) {
142  // Iterate over all operations at the same nesting level as the if-statement
143  // (not the operations inside the if-statement).
144  for (auto &op : llvm::reverse(*ifOp->getBlock())) {
145  if (op.getNumResults() == 0)
146  continue;
147 
148  // Collect all users of the current operations results.
149  SmallVector<Operation *> users;
150  for (auto result : op.getResults())
151  users.append(llvm::to_vector(result.getUsers()));
152 
153  auto parentsOfUsers =
154  llvm::map_range(users, [](auto user) { return user->getParentOp(); });
155 
156  auto allUsersNestedInIf = llvm::any_of(parentsOfUsers, [&](auto *parent) {
157  return !(isa<mlir::scf::IfOp>(parent) &&
158  parent->getBlock() == op.getBlock());
159  });
160 
161  // Check that all users of the results are nested inside the same scf.if
162  // operation
163  if (allUsersNestedInIf || !llvm::all_equal(parentsOfUsers))
164  continue;
165 
166  DenseMap<Region *, Value> cloneMap;
167  for (auto &use : llvm::make_early_inc_range(op.getUses())) {
168  auto *parentRegion = use.getOwner()->getParentRegion();
169  if (!cloneMap.count(parentRegion)) {
170  OpBuilder builder(&parentRegion->front().front());
171  cloneMap[parentRegion] = builder.clone(op)->getResult(0);
172  }
173  use.set(cloneMap[parentRegion]);
174  }
175  }
176 }
177 
178 /// Perform the actual conversion. Create the if-statement, move the operations
179 /// in its regions and delete the mux.
180 static void doConversion(Operation *op, BranchInfo info,
181  const SmallPtrSetImpl<Operation *> &thenOps,
182  const SmallPtrSetImpl<Operation *> &elseOps) {
183  if (op->getNumResults() != 1)
184  return;
185 
186  // Build the scf.if operation with the scf.yields inside.
187  ImplicitLocOpBuilder builder(op->getLoc(), op);
188  mlir::scf::IfOp ifOp = builder.create<mlir::scf::IfOp>(
189  info.condition,
190  [&](OpBuilder &builder, Location loc) {
191  builder.create<mlir::scf::YieldOp>(loc, info.trueValue);
192  },
193  [&](OpBuilder &builder, Location loc) {
194  builder.create<mlir::scf::YieldOp>(loc, info.falseValue);
195  });
196 
197  op->getResult(0).replaceAllUsesWith(ifOp.getResult(0));
198 
199  for (auto &ops :
200  llvm::make_early_inc_range(op->getParentRegion()->getOps())) {
201  // Move operations into the then-branch if they are only used in there.
202  // The original lexicographical order is preserved.
203  if (thenOps.contains(&ops))
204  ops.moveBefore(ifOp.thenBlock()->getTerminator());
205 
206  // Move operations into the else-branch if they are only used in there.
207  // The original lexicographical order is preserved.
208  if (elseOps.contains(&ops))
209  ops.moveBefore(ifOp.elseBlock()->getTerminator());
210  }
211 
212  op->erase();
213 
214  // NOTE: this is just here for some experimentation purposes
215  // cloneOpsIntoBranchesWhenUsedInBoth(ifOp);
216 }
217 
218 /// Simple helper to invoke the runtime cost interface for every operation in a
219 /// set and sum up the costs.
220 static uint32_t getCostEstimate(const SmallPtrSetImpl<Operation *> &ops) {
221  uint32_t cost = 0;
222 
223  for (auto *op : ops) {
224  if (auto *runtimeCostIF =
225  dyn_cast<RuntimeCostEstimateDialectInterface>(op->getDialect())) {
226  cost += runtimeCostIF->getCostEstimate(op);
227  } else {
228  LLVM_DEBUG(llvm::dbgs() << "No runtime cost estimate was provided for '"
229  << op->getName() << "', using default of 10\n");
230  cost += 10;
231  }
232  }
233 
234  return cost;
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // Decision functions (configure the pass here)
239 //===----------------------------------------------------------------------===//
240 
241 /// Convert concrete operations that should be converted to if-statements to a
242 /// more abstract representation the rest of the pass works with. This is the
243 /// place where support for more operations can be added (nothing else has to be
244 /// changed).
245 static BranchInfo getConversionInfo(Operation *op) {
246  if (auto mux = dyn_cast<comb::MuxOp>(op))
247  return BranchInfo{mux.getCond(), mux.getTrueValue(), mux.getFalseValue()};
248 
249  // TODO: we can also check for arith.select or other operations here
250 
251  return {};
252 }
253 
254 /// Use the cost measure of each branch to heuristically decide whether to
255 /// actually perform the conversion.
256 /// TODO: improve and fine-tune this
257 static bool isBeneficialToConvert(Operation *op,
258  const SmallPtrSetImpl<Operation *> &thenOps,
259  const SmallPtrSetImpl<Operation *> &elseOps) {
260  const uint32_t thenCost = getCostEstimate(thenOps);
261  const uint32_t elseCost = getCostEstimate(elseOps);
262 
263  // Due to the nature of mux sequences we need to make sure that a reasonable
264  // amount of operations stay in each if-branch because otherwise we end up
265  // with if-statements that only contain anther if-statement, which is usually
266  // more costly than keeping some muxes unconverted.
267  if (auto parent = op->getParentOfType<mlir::scf::IfOp>()) {
268  SmallPtrSet<Operation *, 32> ifBranchOps;
269 
270  for (auto &nestedOp : *op->getBlock()) {
271  if (!thenOps.contains(&nestedOp) && !elseOps.contains(&nestedOp))
272  ifBranchOps.insert(&nestedOp);
273  }
274 
275  if (getCostEstimate(ifBranchOps) < 100)
276  return false;
277  }
278 
279  // return thenCost + elseCost >= 100 && (thenCost == 0 || elseCost == 0);
280  return (thenCost >= 100 || thenCost == 0) &&
281  (elseCost >= 100 || elseCost == 0) &&
282  std::abs((int)thenCost - (int)elseCost) >= 100;
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // MuxToControlFlow pass definitions
287 //===----------------------------------------------------------------------===//
288 
289 // FIXME: Assumes that the regions in which muxes exist are topologically
290 // ordered.
291 // FIXME: does not consider side-effects
292 void MuxToControlFlowPass::runOnOperation() {
293  // Collect all operations that support the conversion to scf.if operations.
294  // Use 'walk' instead of 'getOps' as we also want to visit nested regions.
295  // We need to collect them because moving ops while iterating over them
296  // would require complicated iterator advancing/skipping but also tracking
297  // back to not miss supported operations.
298  SmallVector<Operation *> supportedOps;
299  getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
300  // Skip ops with graph regions and ops that can contain ops with write
301  // semantics for now until side-effects and topological ordering is properly
302  // handled.
303  if (isa<hw::HWModuleOp, arc::ModelOp>(op))
304  return WalkResult::skip();
305 
306  if (getConversionInfo(op))
307  supportedOps.push_back(op);
308 
309  return WalkResult::advance();
310  });
311 
312  // We want to visit the operations bottom-up to visit the operations with the
313  // longest fan-in first. However, the other direction would also work with the
314  // current implementation.
315  for (auto *op : llvm::reverse(supportedOps)) {
316  auto info = getConversionInfo(op);
317 
318  // Compute the operations in the fan-in of each branch and use them to
319  // decide whether the operation should be converted.
320  // Stop at the first value that's also used outside of the branch.
321  llvm::SmallPtrSet<Operation *, 32> thenOps, elseOps;
322  computeFanIn(op, info.trueValue, thenOps);
323  computeFanIn(op, info.falseValue, elseOps);
324 
325  // Apply a cost measure to the operations in the branches and only convert
326  // when a performance increase can be expected.
327  if (isBeneficialToConvert(op, thenOps, elseOps)) {
328  doConversion(op, info, thenOps, elseOps);
329  ++numMuxesConverted;
330  } else {
331  ++numMuxesRetained;
332  }
333  }
334 }
335 
336 std::unique_ptr<Pass> arc::createMuxToControlFlowPass() {
337  return std::make_unique<MuxToControlFlowPass>();
338 }
static void doConversion(Operation *op, BranchInfo info, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Perform the actual conversion.
static uint32_t getCostEstimate(const SmallPtrSetImpl< Operation * > &ops)
Simple helper to invoke the runtime cost interface for every operation in a set and sum up the costs.
static bool isBeneficialToConvert(Operation *op, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Use the cost measure of each branch to heuristically decide whether to actually perform the conversio...
static void computeFanIn(Operation *mux, Value useValue, SmallPtrSetImpl< Operation * > &visited)
Compute the set of operations that would only be used in the branch represented by useValue.
static bool isValidToProceedTraversal(Operation *mux, Operation *curr, Value useValue, SmallPtrSetImpl< Operation * > &visited)
Check whether curr is valid to be moved into the if-branch, which is the stopping condition of the BFS traversal.
static BranchInfo getConversionInfo(Operation *op)
Convert concrete operations that should be converted to if-statements to a more abstract representati...
static void cloneOpsIntoBranchesWhenUsedInBoth(mlir::scf::IfOp ifOp)
Clone ops that are used in both branches of an if-statement but not outside of it.
std::unique_ptr< mlir::Pass > createMuxToControlFlowPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21