CIRCT 20.0.0git
Loading...
Searching...
No Matches
SplitMerges.cpp
Go to the documentation of this file.
1//===- SplitMerges.cpp - handshake merge deconstruction pass --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the handshake merge deconstruction pass.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Rewrite/FrozenRewritePatternSet.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21namespace circt {
22namespace handshake {
23#define GEN_PASS_DEF_HANDSHAKESPLITMERGES
24#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
25} // namespace handshake
26} // namespace circt
27
28using namespace circt;
29using namespace handshake;
30using namespace mlir;
31
32namespace {
33
34struct DeconstructMergePattern : public OpRewritePattern<handshake::MergeOp> {
35 using OpRewritePattern::OpRewritePattern;
36
37 LogicalResult matchAndRewrite(handshake::MergeOp mergeOp,
38 PatternRewriter &rewriter) const override {
39 if (mergeOp.getNumOperands() <= 2)
40 return failure();
41
42 llvm::SmallVector<Value> mergeInputs;
43 llvm::copy(mergeOp.getOperands(), std::back_inserter(mergeInputs));
44
45 // Recursively build a balanced 2-input merge tree.
46 while (mergeInputs.size() > 1) {
47 llvm::SmallVector<Value> newMergeInputs;
48 for (unsigned i = 0, e = mergeInputs.size(); i < ((e / 2) * 2); i += 2) {
49 auto cm2 = rewriter.create<handshake::MergeOp>(
50 mergeOp.getLoc(), ValueRange{mergeInputs[i], mergeInputs[i + 1]});
51 newMergeInputs.push_back(cm2.getResult());
52 }
53 if (mergeInputs.size() % 2 != 0)
54 newMergeInputs.push_back(mergeInputs.back());
55
56 mergeInputs = newMergeInputs;
57 }
58
59 assert(mergeInputs.size() == 1);
60 rewriter.replaceOp(mergeOp, mergeInputs[0]);
61
62 return success();
63 }
64};
65
66struct DeconstructCMergePattern
67 : public OpRewritePattern<handshake::ControlMergeOp> {
68 using OpRewritePattern::OpRewritePattern;
69
70 LogicalResult matchAndRewrite(handshake::ControlMergeOp cmergeOp,
71 PatternRewriter &rewriter) const override {
72 if (cmergeOp.getNumOperands() <= 2)
73 return failure();
74
75 Type cmergeIndexType = cmergeOp.getIndex().getType();
76 auto loc = cmergeOp.getLoc();
77
78 // Function for create a cmerge-pack structure which generates a
79 // tuple<index, data> from two operands and an index offset.
80 auto mergeTwoOperands = [&](Value op0, Value op1,
81 unsigned idxOffset) -> Value {
82 auto cm2 = rewriter.create<handshake::ControlMergeOp>(
83 loc, ValueRange{op0, op1}, cmergeIndexType);
84 Value idxOperand = cm2.getIndex();
85 if (idxOffset != 0) {
86 // Non-zero index offset; add it to the index operand.
87 idxOperand = rewriter.create<arith::AddIOp>(
88 loc, idxOperand,
89 rewriter.create<arith::ConstantOp>(
90 loc, rewriter.getIntegerAttr(cmergeIndexType, idxOffset)));
91 }
92
93 // Pack index and data into a tuple s.t. they share control.
94 return rewriter.create<handshake::PackOp>(
95 loc, ValueRange{cm2.getResult(), idxOperand});
96 };
97
98 llvm::SmallVector<Value> packedTuples;
99 // Perform the two-operand merges.
100 for (unsigned i = 0, e = cmergeOp.getNumOperands(); i < ((e / 2) * 2);
101 i += 2) {
102 packedTuples.push_back(mergeTwoOperands(cmergeOp.getOperand(i),
103 cmergeOp.getOperand(i + 1), i));
104 }
105 if (cmergeOp.getNumOperands() % 2 != 0) {
106 // If there is an odd number of operands, the last operand becomes a tuple
107 // of itself with an index of the number of operands - 1.
108 unsigned lastIdx = cmergeOp.getNumOperands() - 1;
109 packedTuples.push_back(rewriter.create<handshake::PackOp>(
110 loc, ValueRange{cmergeOp.getOperand(lastIdx),
111 rewriter.create<arith::ConstantOp>(
112 loc, rewriter.getIntegerAttr(cmergeIndexType,
113 lastIdx))}));
114 }
115
116 // Non-deterministically merge the tuples and unpack the result.
117 auto mergedTuple =
118 rewriter.create<handshake::MergeOp>(loc, ValueRange(packedTuples));
119
120 // And finally, replace the original cmerge with the unpacked result.
121 rewriter.replaceOpWithNewOp<handshake::UnpackOp>(cmergeOp,
122 mergedTuple.getResult());
123 return success();
124 }
125};
126
127struct HandshakeSplitMerges
128 : public circt::handshake::impl::HandshakeSplitMergesBase<
129 HandshakeSplitMerges> {
130 void runOnOperation() override {
131 RewritePatternSet patterns(&getContext());
132 patterns.insert<DeconstructCMergePattern, DeconstructMergePattern>(
133 &getContext());
134
135 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
136 signalPassFailure();
137 };
138};
139} // namespace
140
142 return std::make_unique<HandshakeSplitMerges>();
143}
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createHandshakeSplitMergesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.