CIRCT 22.0.0git
Loading...
Searching...
No Matches
RpcClient.cpp
Go to the documentation of this file.
1//===- RpcClient.cpp - ESI Cosim RPC client implementation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DO NOT EDIT!
10// This file is distributed as part of an ESI package. The source for this file
11// should always be modified within CIRCT
12// (lib/dialect/ESI/runtime/cpp/lib/backends/RpcClient.cpp).
13//
14//===----------------------------------------------------------------------===//
15
17#include "esi/Utils.h"
18
19#include "cosim.grpc.pb.h"
20
21#include <grpc/grpc.h>
22#include <grpcpp/channel.h>
23#include <grpcpp/client_context.h>
24#include <grpcpp/create_channel.h>
25#include <grpcpp/security/credentials.h>
26
27#include <condition_variable>
28#include <mutex>
29#include <thread>
30
31using namespace esi;
32using namespace esi::backends::cosim;
33
34using grpc::ClientContext;
35using grpc::Status;
36
37static void checkStatus(Status s, const std::string &msg) {
38 if (!s.ok())
39 throw std::runtime_error(msg + ". Code " + to_string(s.error_code()) +
40 ": " + s.error_message() + " (" +
41 s.error_details() + ")");
42}
43
44//===----------------------------------------------------------------------===//
45// ReadChannelConnectionImpl - gRPC streaming reader implementation
46//===----------------------------------------------------------------------===//
47
48namespace {
// Streaming reader for one client-bound channel, built on gRPC's callback API
// (grpc::ClientReadReactor). start() opens a ConnectToClientChannel stream;
// each message received in OnReadDone is handed to the user callback, and the
// next read is only issued once the callback has accepted the message.
//
// NOTE(review): this rendering of the file is missing original source lines
// 50, 55 and 151 (presumably the RpcClient::ReadChannelConnection base class,
// the `callback` constructor parameter and the `callback` member). Restore
// them from the CIRCT source before compiling.
49class ReadChannelConnectionImpl
51 public grpc::ClientReadReactor<::esi::cosim::Message> {
52public:
// `stub` is not owned; it is only dereferenced in start(). `desc` is copied.
53 ReadChannelConnectionImpl(::esi::cosim::ChannelServer::Stub *stub,
54 const ::esi::cosim::ChannelDesc &desc,
56 : stub(stub), grpcDesc(desc), callback(std::move(callback)),
57 context(nullptr), done(false) {}
58
// Destruction tears down the stream via disconnect().
59 ~ReadChannelConnectionImpl() override { disconnect(); }
60
// Allocates the ClientContext (freed in disconnect()), binds this reactor to
// a ConnectToClientChannel stream for grpcDesc, starts the call and posts the
// first read into incomingMessage.
61 void start() {
62 context = new ClientContext();
63 stub->async()->ConnectToClientChannel(context, &grpcDesc, this);
64 StartCall();
65 StartRead(&incomingMessage);
66 }
67
68 // Utility to check if we are disconnecting. If so, mark done and notify.
// NOTE(review): setting `done` here (not only in OnDone) lets disconnect()
// stop waiting before gRPC has delivered OnDone. disconnect() then deletes
// `context`, and the destructor may free this reactor while the RPC could
// still be live. Confirm this cannot cause a use-after-free.
69 bool isDisconnecting() {
70 std::lock_guard<std::mutex> lock(doneMutex);
71 if (disconnecting) {
72 done = true;
73 doneCV.notify_all();
74 return true;
75 }
76 return false;
77 }
78
// gRPC callback for a completed StartRead.
79 void OnReadDone(bool ok) override {
80 if (!ok)
81 // ok == false: no message was read. The stream is over, e.g. because
// disconnect() cancelled the call, or the stream ended or failed.
82 return;
83
84 // Check if we're disconnecting before processing.
85 if (isDisconnecting())
86 return;
87
88 // Wrap the received payload bytes as MessageData for the callback.
89 const std::string &messageString = incomingMessage.data();
90 MessageData data(reinterpret_cast<const uint8_t *>(messageString.data()),
91 messageString.size());
92
93 // Process the callback. Check disconnecting to avoid blocking forever.
// A false return means the callback did not accept the message yet. Retry
// every 10ms. NOTE(review): this sleeps inside the reactor callback,
// presumably on a gRPC-owned thread.
94 while (!callback(data)) {
95 if (isDisconnecting())
96 return;
97 std::this_thread::sleep_for(std::chrono::milliseconds(10));
98 }
99
100 // Check again before starting new read.
101 if (isDisconnecting())
102 return;
103
104 // Initiate the next read.
105 StartRead(&incomingMessage);
106 }
107
108 // Called by gRPC when the RPC is fully complete (after cancel or error).
109 void OnDone(const grpc::Status & /*status*/) override {
110 std::lock_guard<std::mutex> lock(doneMutex);
111 done = true;
112 doneCV.notify_all();
113 }
114
// Idempotent teardown. Only the first call does anything:
// - if start() was never called (no context), it only sets `disconnecting`;
// - if the RPC already finished, it just frees the context;
// - otherwise it cancels the RPC, waits up to 1000ms for `done`, then frees
//   the context.
115 void disconnect() override {
116 {
117 std::lock_guard<std::mutex> lock(doneMutex);
118
119 // Mark disconnecting first to prevent OnReadDone from starting new reads.
120 if (disconnecting)
121 return;
122 disconnecting = true;
123
124 if (!context)
125 return;
126
127 // If already done, just clean up.
128 if (done) {
129 delete context;
130 context = nullptr;
131 return;
132 }
133 }
134
135 // Try to cancel the RPC.
136 context->TryCancel();
137
138 // Wait briefly for OnDone. Use timeout as TryCancel may not immediately
139 // interrupt pending reads.
// NOTE(review): if this times out, the context is deleted below (and the
// object may then be destroyed) while gRPC may still reference it or call
// OnDone. This is a potential use-after-free; confirm.
140 std::unique_lock<std::mutex> lock(doneMutex);
141 doneCV.wait_for(lock, std::chrono::milliseconds(1000),
142 [this]() { return done; });
143
144 delete context;
145 context = nullptr;
146 }
147
148private:
// Not owned; used only by start().
149 ::esi::cosim::ChannelServer::Stub *stub;
150 ::esi::cosim::ChannelDesc grpcDesc;
// Owned: allocated in start(), deleted in disconnect().
152 ClientContext *context;
// Buffer that each StartRead fills in.
153 ::esi::cosim::Message incomingMessage;
154
155 // Synchronization for waiting on gRPC completion.
156 std::mutex doneMutex;
157 std::condition_variable doneCV;
// Set by OnDone (or isDisconnecting); guarded by doneMutex.
158 bool done;
// Set once by disconnect(); guarded by doneMutex.
159 bool disconnecting = false;
160};
161} // namespace
162
163//===----------------------------------------------------------------------===//
164// RpcClient::Impl - internal implementation class
165//===----------------------------------------------------------------------===//
166
168public:
169 Impl(const std::string &hostname, uint16_t port) {
170 auto channel = grpc::CreateChannel(hostname + ":" + std::to_string(port),
171 grpc::InsecureChannelCredentials());
172 stub = ::esi::cosim::ChannelServer::NewStub(channel);
173 }
174
175 ::esi::cosim::ChannelServer::Stub *getStub() const { return stub.get(); }
176
177 ::esi::cosim::Manifest getManifest() const {
178 ::esi::cosim::Manifest response;
179 // To get around the a race condition where the manifest may not be set yet,
180 // loop until it is. TODO: fix this with the DPI API change.
181 do {
182 ClientContext context;
183 ::esi::cosim::VoidMessage arg;
184 Status s = stub->GetManifest(&context, arg, &response);
185 checkStatus(s, "Failed to get manifest");
186 std::this_thread::sleep_for(std::chrono::milliseconds(10));
187 } while (response.esi_version() < 0);
188 return response;
189 }
190
191 bool getChannelDesc(const std::string &channelName,
192 ::esi::cosim::ChannelDesc &desc) const {
193 ClientContext context;
194 ::esi::cosim::VoidMessage arg;
195 ::esi::cosim::ListOfChannels response;
196 Status s = stub->ListChannels(&context, arg, &response);
197 checkStatus(s, "Failed to list channels");
198 for (const auto &channel : response.channels())
199 if (channel.name() == channelName) {
200 desc = channel;
201 return true;
202 }
203 return false;
204 }
205
// Return a descriptor (name, type, direction) for every channel the server
// reports via ListChannels. Throws std::runtime_error (via checkStatus) if the
// RPC fails.
//
// NOTE(review): this rendering is missing original lines 216, 221 and 223
// (presumably the `RpcClient::ChannelDesc desc;` declaration and the two
// direction assignments), so the loop below does not compile as shown. The
// proto-to-RpcClient::ChannelDesc conversion is also duplicated in
// RpcClient::getChannelDesc; consider sharing one helper.
206 std::vector<RpcClient::ChannelDesc> listChannels() const {
207 ClientContext context;
208 ::esi::cosim::VoidMessage arg;
209 ::esi::cosim::ListOfChannels response;
210 Status s = stub->ListChannels(&context, arg, &response);
211 checkStatus(s, "Failed to list channels");
212
213 std::vector<RpcClient::ChannelDesc> result;
214 result.reserve(response.channels_size());
215 for (const auto &grpcDesc : response.channels()) {
217 desc.name = grpcDesc.name();
218 desc.type = grpcDesc.type();
219 if (grpcDesc.dir() ==
220 ::esi::cosim::ChannelDesc::Direction::ChannelDesc_Direction_TO_SERVER)
222 else
224 result.push_back(std::move(desc));
225 }
226 return result;
227 }
228
229 void writeToServer(const std::string &channelName, const MessageData &data) {
230 ClientContext context;
231 ::esi::cosim::AddressedMessage grpcMsg;
232 grpcMsg.set_channel_name(channelName);
233 grpcMsg.mutable_message()->set_data(data.getBytes(), data.getSize());
234 ::esi::cosim::VoidMessage response;
235 grpc::Status sendStatus = stub->SendToServer(&context, grpcMsg, &response);
236 if (!sendStatus.ok())
237 throw std::runtime_error("Failed to write to channel '" + channelName +
238 "': " + std::to_string(sendStatus.error_code()) +
239 " " + sendStatus.error_message() +
240 ". Details: " + sendStatus.error_details());
241 }
242
243 std::unique_ptr<RpcClient::ReadChannelConnection>
244 connectClientReceiver(const std::string &channelName,
245 RpcClient::ReadCallback callback) {
246 ::esi::cosim::ChannelDesc grpcDesc;
247 if (!getChannelDesc(channelName, grpcDesc))
248 throw std::runtime_error("Could not find channel '" + channelName + "'");
249
250 auto connection = std::make_unique<ReadChannelConnectionImpl>(
251 stub.get(), grpcDesc, std::move(callback));
252 connection->start();
253 return connection;
254 }
255
256private:
257 std::unique_ptr<::esi::cosim::ChannelServer::Stub> stub;
258};
259
260//===----------------------------------------------------------------------===//
261// RpcClient
262//===----------------------------------------------------------------------===//
263
264RpcClient::RpcClient(const std::string &hostname, uint16_t port)
265 : impl(std::make_unique<Impl>(hostname, port)) {}
266
// Defined here rather than inline in the header, presumably so that
// std::unique_ptr<Impl> is destroyed where RpcClient::Impl is a complete type
// (Impl is defined earlier in this file). Confirm against RpcClient.h.
267RpcClient::~RpcClient() = default;
268
269uint32_t RpcClient::getEsiVersion() const {
270 return impl->getManifest().esi_version();
271}
272
273std::vector<uint8_t> RpcClient::getCompressedManifest() const {
274 ::esi::cosim::Manifest response = impl->getManifest();
275 std::string compressedManifestStr = response.compressed_manifest();
276 return std::vector<uint8_t>(compressedManifestStr.begin(),
277 compressedManifestStr.end());
278}
279
// Look up channelName on the server and translate the proto descriptor into
// an RpcClient::ChannelDesc (name, type, direction). Returns false, leaving
// desc untouched, if no such channel exists. Throws std::runtime_error if the
// underlying ListChannels RPC fails.
//
// NOTE(review): this rendering is missing original lines 290 and 292
// (presumably the two direction assignments for the if/else below). The same
// conversion is duplicated in Impl::listChannels.
280bool RpcClient::getChannelDesc(const std::string &channelName,
281 ChannelDesc &desc) const {
282 ::esi::cosim::ChannelDesc grpcDesc;
283 if (!impl->getChannelDesc(channelName, grpcDesc))
284 return false;
285
286 desc.name = grpcDesc.name();
287 desc.type = grpcDesc.type();
288 if (grpcDesc.dir() ==
289 ::esi::cosim::ChannelDesc::Direction::ChannelDesc_Direction_TO_SERVER)
291 else
293 return true;
294}
295
296std::vector<RpcClient::ChannelDesc> RpcClient::listChannels() const {
297 return impl->listChannels();
298}
299
300void RpcClient::writeToServer(const std::string &channelName,
301 const MessageData &data) {
302 impl->writeToServer(channelName, data);
303}
304
305std::unique_ptr<RpcClient::ReadChannelConnection>
306RpcClient::connectClientReceiver(const std::string &channelName,
307 ReadCallback callback) {
308 return impl->connectClientReceiver(channelName, std::move(callback));
309}
static std::unique_ptr< Context > context
static void checkStatus(Status s, const std::string &msg)
Definition RpcClient.cpp:37
std::unique_ptr<::esi::cosim::ChannelServer::Stub > stub
Impl(const std::string &hostname, uint16_t port)
void writeToServer(const std::string &channelName, const MessageData &data)
::esi::cosim::Manifest getManifest() const
std::vector< RpcClient::ChannelDesc > listChannels() const
bool getChannelDesc(const std::string &channelName, ::esi::cosim::ChannelDesc &desc) const
::esi::cosim::ChannelServer::Stub * getStub() const
std::unique_ptr< RpcClient::ReadChannelConnection > connectClientReceiver(const std::string &channelName, RpcClient::ReadCallback callback)
A logical chunk of data representing serialized data.
Definition Common.h:113
Abstract handle for a read channel connection.
Definition RpcClient.h:77
std::unique_ptr< ReadChannelConnection > connectClientReceiver(const std::string &channelName, ReadCallback callback)
Connect to a client-bound channel and receive messages via callback.
RpcClient(const std::string &hostname, uint16_t port)
std::vector< uint8_t > getCompressedManifest() const
Get the compressed manifest from the server.
uint32_t getEsiVersion() const
Get the ESI version from the manifest.
void writeToServer(const std::string &channelName, const MessageData &data)
Send a message to a server-bound channel.
bool getChannelDesc(const std::string &channelName, ChannelDesc &desc) const
Get the channel description for a channel name.
std::function< bool(const MessageData &)> ReadCallback
Callback type for receiving messages from a client-bound channel.
Definition RpcClient.h:73
std::vector< ChannelDesc > listChannels() const
List all channels available on the server.
std::unique_ptr< Impl > impl
Definition RpcClient.h:90
Definition esi.py:1
Description of a channel from the server.
Definition RpcClient.h:55