15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Rewrite/FrozenRewritePatternSet.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#define GEN_PASS_DEF_HANDSHAKESPLITMERGES
24#include "circt/Dialect/Handshake/HandshakePasses.h.inc"
// Rewrite pattern on handshake::MergeOp that replaces a merge with a pairwise
// reduction built from two-operand handshake::MergeOp instances.
// NOTE(review): this listing has line numbers fused onto each line and skips
// some original lines (e.g. closing braces); confirm against the real file.
34struct DeconstructMergePattern :
public OpRewritePattern<handshake::MergeOp> {
35 using OpRewritePattern::OpRewritePattern;
// Rewrites `mergeOp` in place through `rewriter`.
37 LogicalResult matchAndRewrite(handshake::MergeOp mergeOp,
38 PatternRewriter &rewriter)
const override {
// Merges with at most two operands are treated specially.
// NOTE(review): the statement guarded by this check is not visible in this
// listing -- presumably an early exit that leaves the op untouched; confirm.
39 if (mergeOp.getNumOperands() <= 2)
// Working list of values still to be merged, seeded with all operands.
42 llvm::SmallVector<Value> mergeInputs;
43 llvm::copy(mergeOp.getOperands(), std::back_inserter(mergeInputs));
// Each pass roughly halves the working list until one value remains.
46 while (mergeInputs.size() > 1) {
47 llvm::SmallVector<Value> newMergeInputs;
// Walk adjacent pairs; (e / 2) * 2 rounds the size down to an even bound so
// that mergeInputs[i + 1] is always in range.
48 for (
unsigned i = 0, e = mergeInputs.size(); i < ((e / 2) * 2); i += 2) {
49 auto cm2 = handshake::MergeOp::create(
50 rewriter, mergeOp.getLoc(),
51 ValueRange{mergeInputs[i], mergeInputs[i + 1]});
52 newMergeInputs.push_back(cm2.getResult());
// An odd leftover value is carried over unchanged to the next pass.
54 if (mergeInputs.size() % 2 != 0)
55 newMergeInputs.push_back(mergeInputs.back());
57 mergeInputs = newMergeInputs;
// The single remaining value replaces the original merge.
60 assert(mergeInputs.size() == 1);
61 rewriter.replaceOp(mergeOp, mergeInputs[0]);
// Rewrite pattern on handshake::ControlMergeOp. It pairs up the operands into
// two-operand control merges, packs each (result, index) pair, merges the
// packed values with one handshake::MergeOp and unpacks the merged value.
// NOTE(review): this listing has line numbers fused onto each line and skips
// some original lines (the base-class clause, closing braces, parts of some
// statements); confirm against the real file.
67struct DeconstructCMergePattern
69 using OpRewritePattern::OpRewritePattern;
// Rewrites `cmergeOp` in place through `rewriter`.
71 LogicalResult matchAndRewrite(handshake::ControlMergeOp cmergeOp,
72 PatternRewriter &rewriter)
const override {
// Control merges with at most two operands are treated specially.
// NOTE(review): the statement guarded by this check is not visible in this
// listing -- presumably an early exit that leaves the op untouched; confirm.
73 if (cmergeOp.getNumOperands() <= 2)
// The new ops reuse the index type and location of the original op.
76 Type cmergeIndexType = cmergeOp.getIndex().getType();
77 auto loc = cmergeOp.getLoc();
// Helper: creates a two-operand control merge of op0/op1 and returns a
// handshake::PackOp of its result and an index value derived from idxOffset.
81 auto mergeTwoOperands = [&](Value op0, Value op1,
82 unsigned idxOffset) -> Value {
83 auto cm2 = handshake::ControlMergeOp::create(
84 rewriter, loc, ValueRange{op0, op1}, cmergeIndexType);
85 Value idxOperand = cm2.getIndex();
// Add idxOffset (as a constant of the index type) to the local index.
// NOTE(review): original lines between the index read and this addition are
// missing from the listing; any condition guarding the add is not visible.
88 idxOperand = arith::AddIOp::create(
89 rewriter, loc, idxOperand,
90 arith::ConstantOp::create(
92 rewriter.getIntegerAttr(cmergeIndexType, idxOffset)));
96 return handshake::PackOp::create(rewriter, loc,
97 ValueRange{cm2.getResult(), idxOperand});
// One packed (data, index) value per operand pair, plus one for an odd
// trailing operand.
100 llvm::SmallVector<Value> packedTuples;
// Pair up operands i and i + 1, passing i as the index offset of the pair.
// NOTE(review): the loop increment is not visible in this listing.
102 for (
unsigned i = 0, e = cmergeOp.getNumOperands(); i < ((e / 2) * 2);
104 packedTuples.push_back(mergeTwoOperands(cmergeOp.getOperand(i),
105 cmergeOp.getOperand(i + 1), i));
// With an odd operand count the last operand has no partner: pack it directly
// together with a constant holding its own operand position.
107 if (cmergeOp.getNumOperands() % 2 != 0) {
110 unsigned lastIdx = cmergeOp.getNumOperands() - 1;
111 packedTuples.push_back(handshake::PackOp::create(
113 ValueRange{cmergeOp.getOperand(lastIdx),
114 arith::ConstantOp::create(
116 rewriter.getIntegerAttr(cmergeIndexType, lastIdx))}));
// Merge all packed values with a single handshake::MergeOp.
// NOTE(review): the declaration that binds this result (used below as
// `mergedTuple`) is not visible in this listing.
121 handshake::MergeOp::create(rewriter, loc, ValueRange(packedTuples));
// Replace the original control merge with an unpack of the merged value.
124 rewriter.replaceOpWithNewOp<handshake::UnpackOp>(cmergeOp,
125 mergedTuple.getResult());
// Pass that applies DeconstructCMergePattern and DeconstructMergePattern to
// the operation it runs on, using the greedy pattern driver.
// NOTE(review): this listing has line numbers fused onto each line and skips
// some original lines; confirm against the real file.
130struct HandshakeSplitMerges
131 :
public circt::handshake::impl::HandshakeSplitMergesBase<
132 HandshakeSplitMerges> {
// Pass entry point: collect the two patterns and apply them greedily.
133 void runOnOperation()
override {
134 RewritePatternSet
patterns(&getContext());
// NOTE(review): the closing arguments of this call are not visible here.
135 patterns.insert<DeconstructCMergePattern, DeconstructMergePattern>(
// NOTE(review): the statement executed when pattern application fails is not
// visible in this listing.
138 if (failed(applyPatternsGreedily(getOperation(), std::move(
patterns))))
145 return std::make_unique<HandshakeSplitMerges>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createHandshakeSplitMergesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.