11 #include "mlir/Interfaces/SideEffectInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/Support/Debug.h"
15 #define DEBUG_TYPE "arc-simplify-variadic-ops"
19 #define GEN_PASS_DEF_SIMPLIFYVARIADICOPS
20 #include "circt/Dialect/Arc/ArcPasses.h.inc"
25 using namespace circt;
// Pass that rewrites eligible variadic operations. The base class
// arc::impl::SimplifyVariadicOpsBase is generated from ArcPasses.h.inc
// (enabled by GEN_PASS_DEF_SIMPLIFYVARIADICOPS above).
30 struct SimplifyVariadicOpsPass
31 :
public arc::impl::SimplifyVariadicOpsBase<SimplifyVariadicOpsPass> {
32 SimplifyVariadicOpsPass() =
default;
// Copy constructor: it ignores `pass` and delegates to the default
// constructor, so a copy starts from default-constructed state.
// NOTE(review): presumably this is done so generated base members (e.g. pass
// statistics) are not copied between pass instances. Confirm.
33 SimplifyVariadicOpsPass(
const SimplifyVariadicOpsPass &pass)
34 : SimplifyVariadicOpsPass() {}
// Pass entry point. It selects candidate ops and processes them.
36 void runOnOperation()
override;
// Rewrites a single candidate op (see its definition).
37 void simplifyOp(Operation *op);
// Walks the pass's operation and collects every op that is a candidate for
// simplification. Candidates are processed after the walk has finished.
// NOTE(review): presumably this is done so the IR is not mutated while it is
// being walked. Confirm.
41 void SimplifyVariadicOpsPass::runOnOperation() {
42 SmallVector<Operation *> opsToProcess;
// An op is a candidate when all of the following hold:
// - it has the IsCommutative trait;
// - it has no regions and no successors;
// - it has exactly one result;
// - it has more than two operands;
// - it is free of memory effects.
43 getOperation().walk([&](Operation *op) {
44 if (op->hasTrait<OpTrait::IsCommutative>() && op->getNumRegions() == 0 &&
45 op->getNumSuccessors() == 0 && op->getNumResults() == 1 &&
46 op->getNumOperands() > 2 && isMemoryEffectFree(op))
47 opsToProcess.push_back(op);
// Iterate over the collected candidates. The loop body (original line 50)
// is not present in this listing; presumably it calls simplifyOp(op).
49 for (
auto *op : opsToProcess)
// Rebuilds `op` as a chain of two-operand ops with the same operation name and
// result types. The chain is:
//   reduced = operands[0]; then op(reduced, value) for each later operand.
// Uses of the original result are then redirected to the final `reduced`.
// NOTE(review): several original lines are missing from this listing,
// including early exits, part of the sort comparator, the already-sorted check
// and the assignment to `reduced`. The comments below describe only what is
// visible.
53 void SimplifyVariadicOpsPass::simplifyOp(Operation *op) {
58 auto *block = op->getBlock();
59 SmallVector<Value> operands;
// Gather the operands and verify that each one lives in `op`'s own block:
// either as an argument of that block or as the result of an op in it.
// Otherwise numOpsSkippedMultipleBlocks is incremented. The statements after
// the increment (original lines 64-66 and 70-72) are not shown; presumably
// they abandon the rewrite.
60 for (
auto operand : op->getOperands()) {
61 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
62 if (blockArg.getOwner() != block) {
63 ++numOpsSkippedMultipleBlocks;
// A value that is not a block argument is an op result, so defOp is
// expected to be non-null here.
67 auto *defOp = operand.getDefiningOp();
68 if (defOp->getBlock() != block) {
69 ++numOpsSkippedMultipleBlocks;
73 operands.push_back(operand);
75 LLVM_DEBUG(llvm::dbgs() <<
"Simplifying " << *op <<
"\n");
// Sort the operands by their position in the block.
// - Two block arguments compare by argument number.
// - Two op results compare by the order of their defining ops.
// Handling of a mixed pair (one block argument, one op result) is on original
// lines 85-86, which are not shown.
79 llvm::sort(operands, [](
auto a,
auto b) {
81 auto aBlockArg = dyn_cast<BlockArgument>(a);
82 auto bBlockArg = dyn_cast<BlockArgument>(b);
83 if (aBlockArg && bBlockArg)
84 return aBlockArg.getArgNumber() < bBlockArg.getArgNumber();
87 auto *aOp = a.getDefiningOp();
88 auto *bOp = b.getDefiningOp();
93 return aOp->isBeforeInBlock(bOp);
// Debug-only dump of the sorted operand list.
95 LLVM_DEBUG(
for (
auto value
96 : operands) llvm::dbgs()
97 <<
"- " << value <<
"\n";);
// Compares the sorted operands pairwise with the op's original operand order.
// The loop body (original lines 101-108) is not shown. NOTE(review):
// presumably it skips ops whose operands are already in order. Confirm.
100 for (
auto [a, b] : llvm::zip(operands, op->getOperands())) {
// operands[0] is read unguarded. NOTE(review): this relies on `op` having
// operands; runOnOperation only collects ops with more than two operands.
109 Value reduced = operands[0];
110 auto builder = OpBuilder::atBlockBegin(block);
// For each remaining operand, move the insertion point to just after that
// operand's defining op, if it has one. Block arguments leave the insertion
// point unchanged. Then create a two-operand op combining the running
// `reduced` value with `value`, using `op`'s name and result types.
// The left-hand side of `.create(...)` (original line 114) and the remaining
// arguments (lines 117-120) are not shown.
111 for (
auto value : llvm::drop_begin(operands)) {
112 if (
auto *defOp = value.getDefiningOp())
113 builder.setInsertionPointAfter(defOp);
115 .create(op->getLoc(), op->getName().getIdentifier(),
116 ValueRange{reduced, value}, op->getResultTypes(),
// Redirect all uses of the original op's single result to the end of the
// newly built chain.
121 op->getResult(0).replaceAllUsesWith(reduced);
// Factory body: returns a newly constructed SimplifyVariadicOpsPass owned by a
// std::unique_ptr. NOTE(review): presumably this is the body of
// createSimplifyVariadicOpsPass(), whose Doxygen signature follows; the
// function header line is not present in this listing.
127 return std::make_unique<SimplifyVariadicOpsPass>();
std::unique_ptr< mlir::Pass > createSimplifyVariadicOpsPass()
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.