16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Transforms/DialectConversion.h"
20 using namespace circt;
21 using namespace handshake;
// Rewrite pattern anchored on handshake::BufferOp. The visible part of
// matchAndRewrite replaces the matched buffer op with the buffer's own
// operand.
// NOTE(review): the return statement and the closing braces of this struct
// are not part of this excerpt.
26 struct RemoveHandshakeBuffers :
public OpRewritePattern<handshake::BufferOp> {
// Reuse the constructors of the OpRewritePattern base.
27 using OpRewritePattern::OpRewritePattern;
29 LogicalResult matchAndRewrite(handshake::BufferOp bufferOp,
30 PatternRewriter &rewriter)
const override {
// Substitute the buffer with the value that feeds it.
31 rewriter.replaceOp(bufferOp, bufferOp.getOperand());
// Pass whose runOnOperation sets up a partial dialect conversion on the
// handshake::FuncOp it is run on: handshake::BufferOp is declared illegal and
// RemoveHandshakeBuffers is the only pattern registered.
// NOTE(review): what happens when the conversion fails is not visible in this
// excerpt.
36 struct HandshakeRemoveBuffersPass
37 :
public HandshakeRemoveBuffersBase<HandshakeRemoveBuffersPass> {
38 void runOnOperation()
override {
39 handshake::FuncOp op = getOperation();
// Any handshake::BufferOp is marked illegal for this conversion.
40 ConversionTarget target(getContext());
41 target.addIllegalOp<handshake::BufferOp>();
42 RewritePatternSet
patterns(&getContext());
43 patterns.insert<RemoveHandshakeBuffers>(&getContext());
// The pattern set is moved into the conversion driver; `patterns` must not be
// used afterwards.
45 if (failed(applyPartialConversion(op, target, std::move(
patterns))))
// Yields true when the type of `arg` is an integer or float type, or is
// NoneType.
// NOTE(review): the enclosing signature is not in this excerpt; presumably
// this is the body of shouldBufferArgument(BlockArgument arg) -- confirm.
54 return arg.getType().isIntOrFloat() || arg.getType().isa<NoneType>();
// Yields true only when neither `definingOp` nor `usingOp` is a BufferOp.
// `definingOp` is tested with isa_and_nonnull, so a null pointer is accepted
// there; `usingOp` is tested with plain isa and therefore has to be non-null.
// NOTE(review): the enclosing signature is not in this excerpt; presumably
// this is the body of isUnbufferedChannel(definingOp, usingOp) -- confirm.
58 return !isa_and_nonnull<BufferOp>(definingOp) && !isa<BufferOp>(usingOp);
// Visible part of a helper that creates a handshake::BufferOp fed by
// `operand` and then redirects uses of `operand` to it.
// NOTE(review): the first line of the signature and the declaration of
// `bufferOp` are not part of this excerpt.
62 unsigned numSlots, BufferTypeEnum bufferType) {
// Remember the builder's current insertion point so that it can be put back
// at the end.
63 auto ip =
builder.saveInsertionPoint();
// Emit the new op immediately after the point where `operand` is defined.
64 builder.setInsertionPointAfterValue(operand);
66 builder.create<handshake::BufferOp>(loc, operand, numSlots, bufferType);
// Redirect uses of `operand` to `bufferOp`, skipping every use whose owner is
// itself a handshake::BufferOp. The op created above takes `operand` as its
// input and is a BufferOp, so its own use is among those left untouched.
67 operand.replaceUsesWithIf(
68 bufferOp, function_ref<
bool(OpOperand &)>([](OpOperand &operand) ->
bool {
69 return !isa<handshake::BufferOp>(operand.getOwner());
// Put the builder back where the caller left it.
71 builder.restoreInsertionPoint(ip);
// Visible part of a helper that walks every result of `op`.
// NOTE(review): the first line of the signature and the statements following
// the `if` are not part of this excerpt.
76 BufferTypeEnum bufferType) {
77 for (
auto res : op->getResults()) {
// Only the first user of the result is inspected.
// NOTE(review): dereferencing begin() assumes the result has at least one
// user -- confirm that unused results cannot reach this point.
78 Operation *user = *res.getUsers().begin();
// Tests whether that first user is a handshake::BufferOp.
79 if (isa<handshake::BufferOp>(user))
// Visible loop skeleton of a strategy over region `r`; `bufferType` defaults
// to BufferTypeEnum::seq when the caller does not pass one.
// NOTE(review): the first line of the signature and all loop bodies are not
// part of this excerpt.
87 BufferTypeEnum bufferType = BufferTypeEnum::seq) {
// First loop: every argument of the region.
89 for (
auto &arg : r.getArguments()) {
// Second loop nest: every (op, result, user-of-result) triple in the region.
95 for (
auto &defOp : r.getOps()) {
96 for (
auto res : defOp.getResults()) {
97 for (
auto *useOp : res.getUsers()) {
// Visible part of a worklist traversal that starts at `src` and follows
// operation users. `breaksCycle` is a caller-supplied predicate evaluated on
// each operation the first time it is taken off the worklist.
// NOTE(review): the first line of the signature, the statements guarded by
// the two `if`s, any filtering inside the user loop and the return statements
// are not part of this excerpt.
109 llvm::function_ref<
bool(Operation *)> breaksCycle) {
// `visited` records operations already processed; `stack` is the worklist,
// seeded with `src`.
110 SetVector<Operation *> visited;
111 SmallVector<Operation *> stack = {src};
113 while (!stack.empty()) {
// Elements are taken from the back of the worklist (LIFO order).
114 Operation *curr = stack.pop_back_val();
116 if (visited.contains(curr))
118 visited.insert(curr);
120 if (breaksCycle(curr))
// Users of the current operation are appended to the worklist.
123 for (
auto *user : curr->getUsers()) {
128 stack.push_back(user);
// Visible part of a strategy that visits every op in region `r` implementing
// MergeLikeOpInterface and selects a buffer type for it.
// The trailing BufferTypeEnum parameter is unnamed (default
// BufferTypeEnum::seq), so the body cannot refer to it.
// NOTE(review): the first line of the signature and the call that receives
// the selected buffer type are not part of this excerpt.
137 BufferTypeEnum = BufferTypeEnum::seq) {
// Predicate: true when `op` is a handshake::BufferOp whose isSequential()
// returns true.
143 auto isSeqBuffer = [](
auto op) {
144 auto bufferOp = dyn_cast<handshake::BufferOp>(op);
145 return bufferOp && bufferOp.isSequential();
148 for (
auto mergeOp : r.getOps<MergeLikeOpInterface>()) {
// `sequential` holds the result of inCycle() invoked with isSeqBuffer as its
// breaksCycle predicate.
152 bool sequential =
inCycle(mergeOp, isSeqBuffer);
// seq is selected when `sequential` is true, fifo otherwise.
154 sequential ? BufferTypeEnum::seq : BufferTypeEnum::fifo);
// Trailing arguments of two separate calls: the first call is passed
// BufferTypeEnum::seq, the second BufferTypeEnum::fifo.
// NOTE(review): the callees and the enclosing function are not part of this
// excerpt; presumably this is bufferAllFIFOStrategy -- confirm.
165 BufferTypeEnum::seq);
168 BufferTypeEnum::fifo);
// Visible part of a dispatcher on the `strategy` string. The recognised
// values are "cycles", "all" and "allFIFO"; the comparison is exact and
// case-sensitive. For any other value an op error is emitted on the parent
// operation of region `r` and the resulting diagnostic is returned.
// NOTE(review): the first line of the signature and the statement executed in
// each recognised branch are not part of this excerpt.
173 unsigned bufferSize) {
174 if (strategy ==
"cycles")
176 else if (strategy ==
"all")
178 else if (strategy ==
"allFIFO")
// Fallback for an unrecognised strategy name.
181 return r.getParentOp()->emitOpError()
182 <<
"Unknown buffer strategy: " << strategy;
// Pass configured with a strategy name and a buffer size.
// NOTE(review): the rest of runOnOperation and the closing braces are not
// part of this excerpt.
188 struct HandshakeInsertBuffersPass
189 :
public HandshakeInsertBuffersBase<HandshakeInsertBuffersPass> {
// Stores both constructor arguments in the members of the same name.
// NOTE(review): `strategy` and `bufferSize` are not declared in this struct;
// presumably they come from HandshakeInsertBuffersBase -- confirm.
190 HandshakeInsertBuffersPass(
const std::string &strategy,
unsigned bufferSize) {
191 this->strategy = strategy;
192 this->bufferSize = bufferSize;
195 void runOnOperation()
override {
196 auto f = getOperation();
// The builder is created from the context of the operation being processed.
200 OpBuilder
builder(f.getContext());
// Factory: returns a newly allocated HandshakeRemoveBuffersPass, owned by the
// returned std::unique_ptr<mlir::Pass>.
// NOTE(review): the line carrying the function name is not part of this
// excerpt.
209 std::unique_ptr<mlir::Pass>
211 return std::make_unique<HandshakeRemoveBuffersPass>();
// Factory: returns a newly allocated HandshakeInsertBuffersPass constructed
// from `strategy` and `bufferSize`, owned by the returned std::unique_ptr.
// NOTE(review): the line carrying the function name and the `strategy`
// parameter is not part of this excerpt.
214 std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
216 unsigned bufferSize) {
217 return std::make_unique<HandshakeInsertBuffersPass>(strategy, bufferSize);
static void bufferResults(OpBuilder &builder, Operation *op, unsigned numSlots, BufferTypeEnum bufferType)
static void bufferAllFIFOStrategy(Region &r, OpBuilder &builder, unsigned numSlots)
static bool isUnbufferedChannel(Operation *definingOp, Operation *usingOp)
static void insertBuffer(Location loc, Value operand, OpBuilder &builder, unsigned numSlots, BufferTypeEnum bufferType)
static void bufferAllStrategy(Region &r, OpBuilder &builder, unsigned numSlots, BufferTypeEnum bufferType=BufferTypeEnum::seq)
static void bufferCyclesStrategy(Region &r, OpBuilder &builder, unsigned numSlots, BufferTypeEnum=BufferTypeEnum::seq)
static bool shouldBufferArgument(BlockArgument arg)
static bool inCycle(Operation *src, llvm::function_ref< bool(Operation *)> breaksCycle)
std::unique_ptr< mlir::OperationPass< handshake::FuncOp > > createHandshakeInsertBuffersPass(const std::string &strategy="all", unsigned bufferSize=2)
std::unique_ptr< mlir::Pass > createHandshakeRemoveBuffersPass()
LogicalResult bufferRegion(Region &r, OpBuilder &rewriter, StringRef strategy, unsigned bufferSize)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...