CIRCT 22.0.0git
Loading...
Searching...
No Matches
SplitMerges.cpp
Go to the documentation of this file.
1//===- SplitMerges.cpp - handshake merge deconstruction pass --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the handshake merge deconstruction pass.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Rewrite/FrozenRewritePatternSet.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21namespace circt {
22namespace handshake {
23#define GEN_PASS_DEF_HANDSHAKESPLITMERGES
24#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
25} // namespace handshake
26} // namespace circt
27
28using namespace circt;
29using namespace handshake;
30using namespace mlir;
31
32namespace {
33
34struct DeconstructMergePattern : public OpRewritePattern<handshake::MergeOp> {
35 using OpRewritePattern::OpRewritePattern;
36
37 LogicalResult matchAndRewrite(handshake::MergeOp mergeOp,
38 PatternRewriter &rewriter) const override {
39 if (mergeOp.getNumOperands() <= 2)
40 return failure();
41
42 llvm::SmallVector<Value> mergeInputs;
43 llvm::copy(mergeOp.getOperands(), std::back_inserter(mergeInputs));
44
45 // Recursively build a balanced 2-input merge tree.
46 while (mergeInputs.size() > 1) {
47 llvm::SmallVector<Value> newMergeInputs;
48 for (unsigned i = 0, e = mergeInputs.size(); i < ((e / 2) * 2); i += 2) {
49 auto cm2 = handshake::MergeOp::create(
50 rewriter, mergeOp.getLoc(),
51 ValueRange{mergeInputs[i], mergeInputs[i + 1]});
52 newMergeInputs.push_back(cm2.getResult());
53 }
54 if (mergeInputs.size() % 2 != 0)
55 newMergeInputs.push_back(mergeInputs.back());
56
57 mergeInputs = newMergeInputs;
58 }
59
60 assert(mergeInputs.size() == 1);
61 rewriter.replaceOp(mergeOp, mergeInputs[0]);
62
63 return success();
64 }
65};
66
67struct DeconstructCMergePattern
68 : public OpRewritePattern<handshake::ControlMergeOp> {
69 using OpRewritePattern::OpRewritePattern;
70
71 LogicalResult matchAndRewrite(handshake::ControlMergeOp cmergeOp,
72 PatternRewriter &rewriter) const override {
73 if (cmergeOp.getNumOperands() <= 2)
74 return failure();
75
76 Type cmergeIndexType = cmergeOp.getIndex().getType();
77 auto loc = cmergeOp.getLoc();
78
79 // Function for create a cmerge-pack structure which generates a
80 // tuple<index, data> from two operands and an index offset.
81 auto mergeTwoOperands = [&](Value op0, Value op1,
82 unsigned idxOffset) -> Value {
83 auto cm2 = handshake::ControlMergeOp::create(
84 rewriter, loc, ValueRange{op0, op1}, cmergeIndexType);
85 Value idxOperand = cm2.getIndex();
86 if (idxOffset != 0) {
87 // Non-zero index offset; add it to the index operand.
88 idxOperand = arith::AddIOp::create(
89 rewriter, loc, idxOperand,
90 arith::ConstantOp::create(
91 rewriter, loc,
92 rewriter.getIntegerAttr(cmergeIndexType, idxOffset)));
93 }
94
95 // Pack index and data into a tuple s.t. they share control.
96 return handshake::PackOp::create(rewriter, loc,
97 ValueRange{cm2.getResult(), idxOperand});
98 };
99
100 llvm::SmallVector<Value> packedTuples;
101 // Perform the two-operand merges.
102 for (unsigned i = 0, e = cmergeOp.getNumOperands(); i < ((e / 2) * 2);
103 i += 2) {
104 packedTuples.push_back(mergeTwoOperands(cmergeOp.getOperand(i),
105 cmergeOp.getOperand(i + 1), i));
106 }
107 if (cmergeOp.getNumOperands() % 2 != 0) {
108 // If there is an odd number of operands, the last operand becomes a tuple
109 // of itself with an index of the number of operands - 1.
110 unsigned lastIdx = cmergeOp.getNumOperands() - 1;
111 packedTuples.push_back(handshake::PackOp::create(
112 rewriter, loc,
113 ValueRange{cmergeOp.getOperand(lastIdx),
114 arith::ConstantOp::create(
115 rewriter, loc,
116 rewriter.getIntegerAttr(cmergeIndexType, lastIdx))}));
117 }
118
119 // Non-deterministically merge the tuples and unpack the result.
120 auto mergedTuple =
121 handshake::MergeOp::create(rewriter, loc, ValueRange(packedTuples));
122
123 // And finally, replace the original cmerge with the unpacked result.
124 rewriter.replaceOpWithNewOp<handshake::UnpackOp>(cmergeOp,
125 mergedTuple.getResult());
126 return success();
127 }
128};
129
130struct HandshakeSplitMerges
131 : public circt::handshake::impl::HandshakeSplitMergesBase<
132 HandshakeSplitMerges> {
133 void runOnOperation() override {
134 RewritePatternSet patterns(&getContext());
135 patterns.insert<DeconstructCMergePattern, DeconstructMergePattern>(
136 &getContext());
137
138 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
139 signalPassFailure();
140 };
141};
142} // namespace
143
145 return std::make_unique<HandshakeSplitMerges>();
146}
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createHandshakeSplitMergesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.