15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #define GEN_PASS_DEF_HANDSHAKESPLITMERGES
24 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
28 using namespace circt;
29 using namespace handshake;
34 struct DeconstructMergePattern :
public OpRewritePattern<handshake::MergeOp> {
35 using OpRewritePattern::OpRewritePattern;
37 LogicalResult matchAndRewrite(handshake::MergeOp mergeOp,
38 PatternRewriter &rewriter)
const override {
39 if (mergeOp.getNumOperands() <= 2)
42 llvm::SmallVector<Value> mergeInputs;
43 llvm::copy(mergeOp.getOperands(), std::back_inserter(mergeInputs));
46 while (mergeInputs.size() > 1) {
47 llvm::SmallVector<Value> newMergeInputs;
48 for (
unsigned i = 0, e = mergeInputs.size(); i < ((e / 2) * 2); i += 2) {
49 auto cm2 = rewriter.create<handshake::MergeOp>(
50 mergeOp.getLoc(), ValueRange{mergeInputs[i], mergeInputs[i + 1]});
51 newMergeInputs.push_back(cm2.getResult());
53 if (mergeInputs.size() % 2 != 0)
54 newMergeInputs.push_back(mergeInputs.back());
56 mergeInputs = newMergeInputs;
59 assert(mergeInputs.size() == 1);
60 rewriter.replaceOp(mergeOp, mergeInputs[0]);
66 struct DeconstructCMergePattern
68 using OpRewritePattern::OpRewritePattern;
70 LogicalResult matchAndRewrite(handshake::ControlMergeOp cmergeOp,
71 PatternRewriter &rewriter)
const override {
72 if (cmergeOp.getNumOperands() <= 2)
75 Type cmergeIndexType = cmergeOp.getIndex().getType();
76 auto loc = cmergeOp.getLoc();
80 auto mergeTwoOperands = [&](Value op0, Value op1,
81 unsigned idxOffset) -> Value {
82 auto cm2 = rewriter.create<handshake::ControlMergeOp>(
83 loc, ValueRange{op0, op1}, cmergeIndexType);
84 Value idxOperand = cm2.getIndex();
87 idxOperand = rewriter.create<arith::AddIOp>(
89 rewriter.create<arith::ConstantOp>(
90 loc, rewriter.getIntegerAttr(cmergeIndexType, idxOffset)));
94 return rewriter.create<handshake::PackOp>(
95 loc, ValueRange{cm2.getResult(), idxOperand});
98 llvm::SmallVector<Value> packedTuples;
100 for (
unsigned i = 0, e = cmergeOp.getNumOperands(); i < ((e / 2) * 2);
102 packedTuples.push_back(mergeTwoOperands(cmergeOp.getOperand(i),
103 cmergeOp.getOperand(i + 1), i));
105 if (cmergeOp.getNumOperands() % 2 != 0) {
108 unsigned lastIdx = cmergeOp.getNumOperands() - 1;
109 packedTuples.push_back(rewriter.create<handshake::PackOp>(
110 loc, ValueRange{cmergeOp.getOperand(lastIdx),
111 rewriter.create<arith::ConstantOp>(
112 loc, rewriter.getIntegerAttr(cmergeIndexType,
118 rewriter.create<handshake::MergeOp>(loc, ValueRange(packedTuples));
121 rewriter.replaceOpWithNewOp<handshake::UnpackOp>(cmergeOp,
122 mergedTuple.getResult());
127 struct HandshakeSplitMerges
128 :
public circt::handshake::impl::HandshakeSplitMergesBase<
129 HandshakeSplitMerges> {
130 void runOnOperation()
override {
131 RewritePatternSet
patterns(&getContext());
132 patterns.insert<DeconstructCMergePattern, DeconstructMergePattern>(
136 applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns))))
143 return std::make_unique<HandshakeSplitMerges>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createHandshakeSplitMergesPass()
// The InstanceGraph op interface; see InstanceGraphInterface.td for more details.