16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
19 #include "mlir/Transforms/DialectConversion.h"
23 #define GEN_PASS_DEF_HANDSHAKEREMOVEBUFFERS
24 #define GEN_PASS_DEF_HANDSHAKEINSERTBUFFERS
25 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
29 using namespace circt;
30 using namespace handshake;
35 struct RemoveHandshakeBuffers :
public OpRewritePattern<handshake::BufferOp> {
36 using OpRewritePattern::OpRewritePattern;
38 LogicalResult matchAndRewrite(handshake::BufferOp bufferOp,
39 PatternRewriter &rewriter)
const override {
40 rewriter.replaceOp(bufferOp, bufferOp.getOperand());
45 struct HandshakeRemoveBuffersPass
46 :
public circt::handshake::impl::HandshakeRemoveBuffersBase<
47 HandshakeRemoveBuffersPass> {
48 void runOnOperation()
override {
49 handshake::FuncOp op = getOperation();
50 ConversionTarget target(getContext());
51 target.addIllegalOp<handshake::BufferOp>();
52 RewritePatternSet
patterns(&getContext());
53 patterns.insert<RemoveHandshakeBuffers>(&getContext());
55 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
64 return arg.getType().isIntOrFloat() || isa<NoneType>(arg.getType());
68 return !isa_and_nonnull<BufferOp>(definingOp) && !isa<BufferOp>(usingOp);
71 static void insertBuffer(Location loc, Value operand, OpBuilder &builder,
72 unsigned numSlots, BufferTypeEnum bufferType) {
73 auto ip = builder.saveInsertionPoint();
74 builder.setInsertionPointAfterValue(operand);
76 builder.create<handshake::BufferOp>(loc, operand, numSlots, bufferType);
77 operand.replaceUsesWithIf(
78 bufferOp, function_ref<
bool(OpOperand &)>([](OpOperand &operand) ->
bool {
79 return !isa<handshake::BufferOp>(operand.getOwner());
81 builder.restoreInsertionPoint(ip);
85 static void bufferResults(OpBuilder &builder, Operation *op,
unsigned numSlots,
86 BufferTypeEnum bufferType) {
87 for (
auto res : op->getResults()) {
88 Operation *user = *res.getUsers().begin();
89 if (isa<handshake::BufferOp>(user))
91 insertBuffer(op->getLoc(), res, builder, numSlots, bufferType);
97 BufferTypeEnum bufferType = BufferTypeEnum::seq) {
99 for (
auto &arg : r.getArguments()) {
102 insertBuffer(arg.getLoc(), arg, builder, numSlots, bufferType);
105 for (
auto &defOp : r.getOps()) {
106 for (
auto res : defOp.getResults()) {
107 for (
auto *useOp : res.getUsers()) {
110 insertBuffer(res.getLoc(), res, builder, numSlots, bufferType);
119 llvm::function_ref<
bool(Operation *)> breaksCycle) {
120 SetVector<Operation *> visited;
121 SmallVector<Operation *> stack = {src};
123 while (!stack.empty()) {
124 Operation *curr = stack.pop_back_val();
126 if (visited.contains(curr))
128 visited.insert(curr);
130 if (breaksCycle(curr))
133 for (
auto *user : curr->getUsers()) {
138 stack.push_back(user);
147 BufferTypeEnum = BufferTypeEnum::seq) {
153 auto isSeqBuffer = [](
auto op) {
154 auto bufferOp = dyn_cast<handshake::BufferOp>(op);
155 return bufferOp && bufferOp.isSequential();
158 for (
auto mergeOp : r.getOps<MergeLikeOpInterface>()) {
162 bool sequential =
inCycle(mergeOp, isSeqBuffer);
164 sequential ? BufferTypeEnum::seq : BufferTypeEnum::fifo);
175 BufferTypeEnum::seq);
178 BufferTypeEnum::fifo);
182 StringRef
strategy,
unsigned bufferSize) {
190 return r.getParentOp()->emitOpError()
191 <<
"Unknown buffer strategy: " <<
strategy;
197 struct HandshakeInsertBuffersPass
198 :
public circt::handshake::impl::HandshakeInsertBuffersBase<
199 HandshakeInsertBuffersPass> {
200 HandshakeInsertBuffersPass(
const std::string &
strategy,
unsigned bufferSize) {
202 this->bufferSize = bufferSize;
205 void runOnOperation()
override {
206 auto f = getOperation();
210 OpBuilder builder(f.getContext());
219 std::unique_ptr<mlir::Pass>
221 return std::make_unique<HandshakeRemoveBuffersPass>();
224 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
226 unsigned bufferSize) {
227 return std::make_unique<HandshakeInsertBuffersPass>(
strategy, bufferSize);
static void bufferResults(OpBuilder &builder, Operation *op, unsigned numSlots, BufferTypeEnum bufferType)
static void bufferAllFIFOStrategy(Region &r, OpBuilder &builder, unsigned numSlots)
static bool isUnbufferedChannel(Operation *definingOp, Operation *usingOp)
static void insertBuffer(Location loc, Value operand, OpBuilder &builder, unsigned numSlots, BufferTypeEnum bufferType)
static void bufferAllStrategy(Region &r, OpBuilder &builder, unsigned numSlots, BufferTypeEnum bufferType=BufferTypeEnum::seq)
static LogicalResult bufferRegion(Region &r, OpBuilder &builder, StringRef strategy, unsigned bufferSize)
static void bufferCyclesStrategy(Region &r, OpBuilder &builder, unsigned numSlots, BufferTypeEnum=BufferTypeEnum::seq)
static bool shouldBufferArgument(BlockArgument arg)
static bool inCycle(Operation *src, llvm::function_ref< bool(Operation *)> breaksCycle)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeInsertBuffersPass(const std::string &strategy="all", unsigned bufferSize=2)
std::unique_ptr< mlir::Pass > createHandshakeRemoveBuffersPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.