CIRCT  19.0.0git
SimplifyVariadicOps.cpp
Go to the documentation of this file.
1 //===- SimplifyVariadicOps.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Interfaces/SideEffectInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/Support/Debug.h"
14 
15 #define DEBUG_TYPE "arc-simplify-variadic-ops"
16 
17 namespace circt {
18 namespace arc {
19 #define GEN_PASS_DEF_SIMPLIFYVARIADICOPS
20 #include "circt/Dialect/Arc/ArcPasses.h.inc"
21 } // namespace arc
22 } // namespace circt
23 
24 using namespace mlir;
25 using namespace circt;
26 using namespace arc;
27 using namespace hw;
28 
29 namespace {
30 struct SimplifyVariadicOpsPass
31  : public arc::impl::SimplifyVariadicOpsBase<SimplifyVariadicOpsPass> {
32  SimplifyVariadicOpsPass() = default;
33  SimplifyVariadicOpsPass(const SimplifyVariadicOpsPass &pass)
34  : SimplifyVariadicOpsPass() {}
35 
36  void runOnOperation() override;
37  void simplifyOp(Operation *op);
38 };
39 } // namespace
40 
41 void SimplifyVariadicOpsPass::runOnOperation() {
42  SmallVector<Operation *> opsToProcess;
43  getOperation().walk([&](Operation *op) {
44  if (op->hasTrait<OpTrait::IsCommutative>() && op->getNumRegions() == 0 &&
45  op->getNumSuccessors() == 0 && op->getNumResults() == 1 &&
46  op->getNumOperands() > 2 && isMemoryEffectFree(op))
47  opsToProcess.push_back(op);
48  });
49  for (auto *op : opsToProcess)
50  simplifyOp(op);
51 }
52 
53 void SimplifyVariadicOpsPass::simplifyOp(Operation *op) {
54  // Gather the list of operands together with the defining op. Block arguments
55  // simply get no op assigned. This is also where we bail out if the block
56  // argument or any of the defining ops is in a different block than the op
57  // itself.
58  auto *block = op->getBlock();
59  SmallVector<Value> operands;
60  for (auto operand : op->getOperands()) {
61  if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
62  if (blockArg.getOwner() != block) {
63  ++numOpsSkippedMultipleBlocks;
64  return;
65  }
66  } else {
67  auto *defOp = operand.getDefiningOp();
68  if (defOp->getBlock() != block) {
69  ++numOpsSkippedMultipleBlocks;
70  return;
71  }
72  }
73  operands.push_back(operand);
74  }
75  LLVM_DEBUG(llvm::dbgs() << "Simplifying " << *op << "\n");
76 
77  // Sort the list of operands based on the order in which their defining ops
78  // appear in the block.
79  llvm::sort(operands, [](auto a, auto b) {
80  // Sort block args by the arg number.
81  auto aBlockArg = dyn_cast<BlockArgument>(a);
82  auto bBlockArg = dyn_cast<BlockArgument>(b);
83  if (aBlockArg && bBlockArg)
84  return aBlockArg.getArgNumber() < bBlockArg.getArgNumber();
85 
86  // Sort other values by block order of the defining op.
87  auto *aOp = a.getDefiningOp();
88  auto *bOp = b.getDefiningOp();
89  if (!aOp)
90  return true;
91  if (!bOp)
92  return false;
93  return aOp->isBeforeInBlock(bOp);
94  });
95  LLVM_DEBUG(for (auto value
96  : operands) llvm::dbgs()
97  << "- " << value << "\n";);
98 
99  // Keep some statistics whether we actually did do some reordering.
100  for (auto [a, b] : llvm::zip(operands, op->getOperands())) {
101  if (a != b) {
102  ++numOpsReordered;
103  break;
104  }
105  }
106 
107  // Split up the variadic operation by going through the operands and creating
108  // pairwise versions of the op as close as possible to the operands.
109  Value reduced = operands[0];
110  auto builder = OpBuilder::atBlockBegin(block);
111  for (auto value : llvm::drop_begin(operands)) {
112  if (auto *defOp = value.getDefiningOp())
113  builder.setInsertionPointAfter(defOp);
114  reduced = builder
115  .create(op->getLoc(), op->getName().getIdentifier(),
116  ValueRange{reduced, value}, op->getResultTypes(),
117  op->getAttrs())
118  ->getResult(0);
119  ++numOpsCreated;
120  }
121  op->getResult(0).replaceAllUsesWith(reduced);
122  op->erase();
123  ++numOpsSimplified;
124 }
125 
126 std::unique_ptr<Pass> arc::createSimplifyVariadicOpsPass() {
127  return std::make_unique<SimplifyVariadicOpsPass>();
128 }
Builder builder
std::unique_ptr< mlir::Pass > createSimplifyVariadicOpsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1