CIRCT  20.0.0git
SplitMerges.cpp
Go to the documentation of this file.
1 //===- SplitMerges.cpp - handshake merge deconstruction pass --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the handshake merge deconstruction pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 namespace circt {
22 namespace handshake {
23 #define GEN_PASS_DEF_HANDSHAKESPLITMERGES
24 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
25 } // namespace handshake
26 } // namespace circt
27 
28 using namespace circt;
29 using namespace handshake;
30 using namespace mlir;
31 
32 namespace {
33 
34 struct DeconstructMergePattern : public OpRewritePattern<handshake::MergeOp> {
35  using OpRewritePattern::OpRewritePattern;
36 
37  LogicalResult matchAndRewrite(handshake::MergeOp mergeOp,
38  PatternRewriter &rewriter) const override {
39  if (mergeOp.getNumOperands() <= 2)
40  return failure();
41 
42  llvm::SmallVector<Value> mergeInputs;
43  llvm::copy(mergeOp.getOperands(), std::back_inserter(mergeInputs));
44 
45  // Recursively build a balanced 2-input merge tree.
46  while (mergeInputs.size() > 1) {
47  llvm::SmallVector<Value> newMergeInputs;
48  for (unsigned i = 0, e = mergeInputs.size(); i < ((e / 2) * 2); i += 2) {
49  auto cm2 = rewriter.create<handshake::MergeOp>(
50  mergeOp.getLoc(), ValueRange{mergeInputs[i], mergeInputs[i + 1]});
51  newMergeInputs.push_back(cm2.getResult());
52  }
53  if (mergeInputs.size() % 2 != 0)
54  newMergeInputs.push_back(mergeInputs.back());
55 
56  mergeInputs = newMergeInputs;
57  }
58 
59  assert(mergeInputs.size() == 1);
60  rewriter.replaceOp(mergeOp, mergeInputs[0]);
61 
62  return success();
63  }
64 };
65 
66 struct DeconstructCMergePattern
67  : public OpRewritePattern<handshake::ControlMergeOp> {
68  using OpRewritePattern::OpRewritePattern;
69 
70  LogicalResult matchAndRewrite(handshake::ControlMergeOp cmergeOp,
71  PatternRewriter &rewriter) const override {
72  if (cmergeOp.getNumOperands() <= 2)
73  return failure();
74 
75  Type cmergeIndexType = cmergeOp.getIndex().getType();
76  auto loc = cmergeOp.getLoc();
77 
78  // Function for create a cmerge-pack structure which generates a
79  // tuple<index, data> from two operands and an index offset.
80  auto mergeTwoOperands = [&](Value op0, Value op1,
81  unsigned idxOffset) -> Value {
82  auto cm2 = rewriter.create<handshake::ControlMergeOp>(
83  loc, ValueRange{op0, op1}, cmergeIndexType);
84  Value idxOperand = cm2.getIndex();
85  if (idxOffset != 0) {
86  // Non-zero index offset; add it to the index operand.
87  idxOperand = rewriter.create<arith::AddIOp>(
88  loc, idxOperand,
89  rewriter.create<arith::ConstantOp>(
90  loc, rewriter.getIntegerAttr(cmergeIndexType, idxOffset)));
91  }
92 
93  // Pack index and data into a tuple s.t. they share control.
94  return rewriter.create<handshake::PackOp>(
95  loc, ValueRange{cm2.getResult(), idxOperand});
96  };
97 
98  llvm::SmallVector<Value> packedTuples;
99  // Perform the two-operand merges.
100  for (unsigned i = 0, e = cmergeOp.getNumOperands(); i < ((e / 2) * 2);
101  i += 2) {
102  packedTuples.push_back(mergeTwoOperands(cmergeOp.getOperand(i),
103  cmergeOp.getOperand(i + 1), i));
104  }
105  if (cmergeOp.getNumOperands() % 2 != 0) {
106  // If there is an odd number of operands, the last operand becomes a tuple
107  // of itself with an index of the number of operands - 1.
108  unsigned lastIdx = cmergeOp.getNumOperands() - 1;
109  packedTuples.push_back(rewriter.create<handshake::PackOp>(
110  loc, ValueRange{cmergeOp.getOperand(lastIdx),
111  rewriter.create<arith::ConstantOp>(
112  loc, rewriter.getIntegerAttr(cmergeIndexType,
113  lastIdx))}));
114  }
115 
116  // Non-deterministically merge the tuples and unpack the result.
117  auto mergedTuple =
118  rewriter.create<handshake::MergeOp>(loc, ValueRange(packedTuples));
119 
120  // And finally, replace the original cmerge with the unpacked result.
121  rewriter.replaceOpWithNewOp<handshake::UnpackOp>(cmergeOp,
122  mergedTuple.getResult());
123  return success();
124  }
125 };
126 
127 struct HandshakeSplitMerges
128  : public circt::handshake::impl::HandshakeSplitMergesBase<
129  HandshakeSplitMerges> {
130  void runOnOperation() override {
131  RewritePatternSet patterns(&getContext());
132  patterns.insert<DeconstructCMergePattern, DeconstructMergePattern>(
133  &getContext());
134 
135  if (failed(
136  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
137  signalPassFailure();
138  };
139 };
140 } // namespace
141 
143  return std::make_unique<HandshakeSplitMerges>();
144 }
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createHandshakeSplitMergesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21