CIRCT 22.0.0git
Loading...
Searching...
No Matches
RpcServer.cpp
Go to the documentation of this file.
1//===- RpcServer.cpp - Run a cosim server ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "esi/Context.h"
11#include "esi/Utils.h"
12
13#include "cosim.grpc.pb.h"
14
15#include <grpc/grpc.h>
16#include <grpcpp/security/server_credentials.h>
17#include <grpcpp/server.h>
18#include <grpcpp/server_builder.h>
19#include <grpcpp/server_context.h>
20
21#include <algorithm>
22#include <cassert>
23#include <cstdlib>
24
25using namespace esi;
26using namespace esi::cosim;
27
28using grpc::CallbackServerContext;
29using grpc::Server;
30using grpc::ServerUnaryReactor;
31using grpc::ServerWriteReactor;
32using grpc::Status;
33using grpc::StatusCode;
34
/// Write the port number to a file. Necessary when we are allowed to select our
/// own port. We can't use stdout/stderr because the flushing semantics are
/// undefined (as in `flush()` doesn't work on all simulators).
///
/// Throws std::runtime_error if "cosim.cfg" cannot be opened or fully written,
/// instead of handing a null FILE* to fprintf/fclose.
static void writePort(uint16_t port) {
  // "cosim.cfg" since we may want to include other info in the future.
  FILE *fd = fopen("cosim.cfg", "w");
  if (!fd)
    throw std::runtime_error("Failed to open cosim.cfg for writing");
  int written = fprintf(fd, "port: %u\n", static_cast<unsigned int>(port));
  // fclose() flushes the buffered line, so its failure also means the file
  // may not contain the port.
  int closed = fclose(fd);
  if (written < 0 || closed != 0)
    throw std::runtime_error("Failed to write port to cosim.cfg");
}
44
namespace {
// The port classes are defined further down this file; Impl's port maps only
// need their names at this point.
class RpcServerWritePort;
class RpcServerReadPort;
} // namespace
49
52public:
53 Impl(Context &ctxt, int port);
54 ~Impl();
55
56 Context &getContext() { return ctxt; }
57
58 //===--------------------------------------------------------------------===//
59 // Internal API
60 //===--------------------------------------------------------------------===//
61
63 const std::vector<uint8_t> &compressedManifest) {
64 this->compressedManifest = compressedManifest;
65 this->esiVersion = esiVersion;
66 }
67
68 ReadChannelPort &registerReadPort(const std::string &name,
69 const std::string &type);
70 WriteChannelPort &registerWritePort(const std::string &name,
71 const std::string &type);
72
73 void stop(uint32_t timeoutMS = 0);
74
75 int getPort() { return port; }
76
77 //===--------------------------------------------------------------------===//
78 // RPC API implementations. See the .proto file for the API documentation.
79 //===--------------------------------------------------------------------===//
80
81 ServerUnaryReactor *GetManifest(CallbackServerContext *context,
82 const VoidMessage *,
83 Manifest *response) override;
84 ServerUnaryReactor *ListChannels(CallbackServerContext *, const VoidMessage *,
85 ListOfChannels *channelsOut) override;
86 ServerWriteReactor<esi::cosim::Message> *
87 ConnectToClientChannel(CallbackServerContext *context,
88 const ChannelDesc *request) override;
89 ServerUnaryReactor *SendToServer(CallbackServerContext *context,
90 const esi::cosim::AddressedMessage *request,
91 esi::cosim::VoidMessage *response) override;
92
93private:
96 std::vector<uint8_t> compressedManifest;
97 std::map<std::string, std::unique_ptr<RpcServerReadPort>> readPorts;
98 std::map<std::string, std::unique_ptr<RpcServerWritePort>> writePorts;
99 int port = -1;
100 std::unique_ptr<Server> server;
101};
103
104//===----------------------------------------------------------------------===//
105// Read and write ports
106//
107// Implemented as simple queues which the RPC server writes to and reads from.
108//===----------------------------------------------------------------------===//
109
110namespace {
111/// Implements a simple read queue. The RPC server will push messages into this
112/// as appropriate.
113class RpcServerReadPort : public ReadChannelPort {
114public:
115 RpcServerReadPort(Type *type) : ReadChannelPort(type) {}
116
117 /// Internal call. Push a message FROM the RPC client to the read port.
118 void push(MessageData &data) {
119 while (!callback(data))
120 std::this_thread::sleep_for(std::chrono::milliseconds(1));
121 }
122};
123
124/// Implements a simple write queue. The RPC server will pull messages from this
125/// as appropriate. Note that this could be more performant if a callback is
126/// used. This would have more complexity as when a client disconnects the
127/// outstanding messages will need somewhere to be held until the next client
128/// connects. For now, it's simpler to just have the server poll the queue.
129class RpcServerWritePort : public WriteChannelPort {
130public:
131 RpcServerWritePort(Type *type) : WriteChannelPort(type) {}
132
134
135protected:
136 void writeImpl(const MessageData &data) override { writeQueue.push(data); }
137 bool tryWriteImpl(const MessageData &data) override {
138 writeQueue.push(data);
139 return true;
140 }
141};
142} // namespace
143
144//===----------------------------------------------------------------------===//
145// RPC server implementations
146//===----------------------------------------------------------------------===//
147
148/// Start a server on the given port. -1 means to let the OS pick a port.
149Impl::Impl(Context &ctxt, int port) : ctxt(ctxt), esiVersion(-1) {
150 grpc::ServerBuilder builder;
151 std::string server_address("127.0.0.1:" + std::to_string(port));
152 // TODO: use secure credentials. Not so bad for now since we only accept
153 // connections on localhost.
154 builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(),
155 &port);
156 builder.RegisterService(this);
157 server = builder.BuildAndStart();
158 if (!server)
159 throw std::runtime_error("Failed to start server on " + server_address);
161 this->port = port;
162 ctxt.getLogger().info("cosim", "Server listening on 127.0.0.1:" +
163 std::to_string(port));
164}
165
166void Impl::stop(uint32_t timeoutMS) {
167 // Disconnect all the ports.
168 for (auto &[name, port] : readPorts)
169 port->disconnect();
170 for (auto &[name, port] : writePorts)
171 port->disconnect();
172
173 // Shutdown the server and wait for it to finish.
174 if (timeoutMS > 0)
175 server->Shutdown(gpr_time_add(
176 gpr_now(GPR_CLOCK_REALTIME),
177 gpr_time_from_millis(static_cast<int>(timeoutMS), GPR_TIMESPAN)));
178 else
179 server->Shutdown();
180
181 server->Wait();
182 server = nullptr;
183}
184
186 if (server)
187 stop();
188}
189
191 const std::string &type) {
192 auto port = new RpcServerReadPort(new Type(type));
193 readPorts.emplace(name, port);
194 return *port;
195}
197 const std::string &type) {
198 auto port = new RpcServerWritePort(new Type(type));
199 writePorts.emplace(name, port);
200 return *port;
201}
202
203ServerUnaryReactor *Impl::GetManifest(CallbackServerContext *context,
204 const VoidMessage *, Manifest *response) {
205 response->set_esi_version(esiVersion);
206 response->set_compressed_manifest(compressedManifest.data(),
207 compressedManifest.size());
208 ServerUnaryReactor *reactor = context->DefaultReactor();
209 reactor->Finish(Status::OK);
210 return reactor;
211}
212
213/// Load the list of channels into the response and fire it off.
214ServerUnaryReactor *Impl::ListChannels(CallbackServerContext *context,
215 const VoidMessage *,
216 ListOfChannels *channelsOut) {
217 for (auto &[name, port] : readPorts) {
218 auto *channel = channelsOut->add_channels();
219 channel->set_name(name);
220 channel->set_type(port->getType()->getID());
221 channel->set_dir(ChannelDesc::Direction::ChannelDesc_Direction_TO_SERVER);
222 }
223 for (auto &[name, port] : writePorts) {
224 auto *channel = channelsOut->add_channels();
225 channel->set_name(name);
226 channel->set_type(port->getType()->getID());
227 channel->set_dir(ChannelDesc::Direction::ChannelDesc_Direction_TO_CLIENT);
228 }
229
230 // The default reactor is basically to just finish the RPC call as if we're
231 // implementing the RPC function as a blocking call.
232 auto reactor = context->DefaultReactor();
233 reactor->Finish(Status::OK);
234 return reactor;
235}
236
237namespace {
238/// When a client connects to a read port (on its end, a write port on this
239/// end), construct one of these to poll the corresponding write port on this
240/// side and forward the messages.
241class RpcServerWriteReactor : public ServerWriteReactor<esi::cosim::Message> {
242public:
243 RpcServerWriteReactor(RpcServerWritePort *writePort)
244 : writePort(writePort), sentSuccessfully(SendStatus::UnknownStatus),
245 shutdown(false), onDoneCalled(false) {
246 myThread = std::thread(&RpcServerWriteReactor::threadLoop, this);
247 }
248
249 // gRPC manages the lifecycle of this object. OnDone() is called when gRPC is
250 // completely done with this reactor. We must wait for our thread to finish
251 // before deleting. See:
252 // https://github.com/grpc/grpc/blob/4795c5e69b25e8c767b498bea784da0ef8c96fd5/examples/cpp/route_guide/route_guide_callback_server.cc#L120
253 void OnDone() override {
254 // Signal shutdown and wake up any waiting threads.
255 {
256 std::scoped_lock<std::mutex> lock(sentMutex);
257 shutdown = true;
258 onDoneCalled = true;
259 }
260 sentSuccessfullyCV.notify_one();
261 onDoneCV.notify_one();
262
263 // Wait for the thread to finish before self-deleting.
264 if (myThread.joinable())
265 myThread.join();
266
267 delete this;
268 }
269
270 void OnWriteDone(bool ok) override {
271 std::scoped_lock<std::mutex> lock(sentMutex);
272 sentSuccessfully = ok ? SendStatus::Success : SendStatus::Failure;
273 sentSuccessfullyCV.notify_one();
274 }
275
276 void OnCancel() override {
277 std::scoped_lock<std::mutex> lock(sentMutex);
278 shutdown = true;
279 sentSuccessfully = SendStatus::Disconnect;
280 sentSuccessfullyCV.notify_one();
281 }
282
283private:
284 /// The polling loop.
285 void threadLoop();
286 /// The polling thread.
287 std::thread myThread;
288
289 /// Assoicated write port on this side. (Read port on the client side.)
290 RpcServerWritePort *writePort;
291
292 /// Mutex to protect the sentSuccessfully flag and shutdown state.
293 std::mutex sentMutex;
294 enum SendStatus { UnknownStatus, Success, Failure, Disconnect };
295 volatile SendStatus sentSuccessfully;
296 std::condition_variable sentSuccessfullyCV;
297
298 std::atomic<bool> shutdown;
299
300 /// Condition variable to wait for OnDone to be called.
301 bool onDoneCalled;
302 std::condition_variable onDoneCV;
303};
304
305} // namespace
306
307void RpcServerWriteReactor::threadLoop() {
308 while (!shutdown && sentSuccessfully != SendStatus::Disconnect) {
309 // TODO: adapt this to a new notification mechanism which is forthcoming.
310 if (!writePort || writePort->writeQueue.empty()) {
311 std::this_thread::sleep_for(std::chrono::microseconds(100));
312 continue;
313 }
314
315 // This lambda will get called with the message at the front of the queue.
316 // If the send is successful, return true to pop it. We don't know, however,
317 // if the message was sent successfully in this thread. It's only when the
318 // `OnWriteDone` method is called by gRPC that we know. Use locking and
319 // condition variables to orchestrate this confirmation.
320 writePort->writeQueue.pop([this](const MessageData &data) -> bool {
321 if (shutdown)
322 return false;
323
324 esi::cosim::Message msg;
325 msg.set_data(reinterpret_cast<const char *>(data.getBytes()),
326 data.getSize());
327
328 // Get a lock, reset the flag, start sending the message, and wait for the
329 // write to complete or fail. Be mindful of the shutdown flag.
330 std::unique_lock<std::mutex> lock(sentMutex);
331 sentSuccessfully = SendStatus::UnknownStatus;
332 StartWrite(&msg);
333 sentSuccessfullyCV.wait(lock, [&]() {
334 return shutdown || sentSuccessfully != SendStatus::UnknownStatus;
335 });
336 bool ret = sentSuccessfully == SendStatus::Success;
337 lock.unlock();
338 return ret;
339 });
340 }
341
342 // Call Finish to signal gRPC that we're done. gRPC will then call OnDone().
343 Finish(Status::OK);
344}
345
/// When a client connects to a read port (a write port on this end), start
/// streaming messages to it until the client cancels the call.
348ServerWriteReactor<esi::cosim::Message> *
350 const ChannelDesc *request) {
351 getContext().getLogger().debug("cosim", "connect to client channel");
352 auto it = writePorts.find(request->name());
353 if (it == writePorts.end()) {
354 auto reactor = new RpcServerWriteReactor(nullptr);
355 reactor->Finish(Status(StatusCode::NOT_FOUND, "Unknown channel"));
356 return reactor;
357 }
358 return new RpcServerWriteReactor(it->second.get());
359}
360
361/// When a client sends a message to a write port (a read port on this end),
362/// simply locate the associated port, and write that message into its queue.
363ServerUnaryReactor *
364Impl::SendToServer(CallbackServerContext *context,
365 const esi::cosim::AddressedMessage *request,
366 esi::cosim::VoidMessage *response) {
367 auto reactor = context->DefaultReactor();
368 auto it = readPorts.find(request->channel_name());
369 if (it == readPorts.end()) {
370 reactor->Finish(Status(StatusCode::NOT_FOUND, "Unknown channel"));
371 return reactor;
372 }
373
374 std::string msgDataString = request->message().data();
375 MessageData data(reinterpret_cast<const uint8_t *>(msgDataString.data()),
376 msgDataString.size());
377 it->second->push(data);
378 reactor->Finish(Status::OK);
379 return reactor;
380}
381
382//===----------------------------------------------------------------------===//
383// RpcServer pass throughs to the actual implementations above.
384//===----------------------------------------------------------------------===//
386RpcServer::~RpcServer() = default;
387
388void RpcServer::setManifest(int esiVersion,
389 const std::vector<uint8_t> &compressedManifest) {
390 if (!impl)
391 throw std::runtime_error("Server not running");
392
393 impl->setManifest(esiVersion, compressedManifest);
394}
395
397 const std::string &type) {
398 if (!impl)
399 throw std::runtime_error("Server not running");
400 return impl->registerReadPort(name, type);
401}
402
404 const std::string &type) {
405 return impl->registerWritePort(name, type);
406}
407void RpcServer::run(int port) {
408 if (impl)
409 throw std::runtime_error("Server already running");
410 impl = std::make_unique<Impl>(ctxt, port);
411}
412void RpcServer::stop(uint32_t timeoutMS) {
413 if (!impl)
414 throw std::runtime_error("Server not running");
415 impl->stop(timeoutMS);
416}
417
419 if (!impl)
420 throw std::runtime_error("Server not running");
421 return impl->getPort();
422}
static std::unique_ptr< Context > context
static void writePort(uint16_t port)
Write the port number to a file.
Definition RpcServer.cpp:38
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:75
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:83
Class to parse a manifest.
Definition Manifest.h:39
A logical chunk of data representing serialized data.
Definition Common.h:113
A ChannelPort which reads data from the accelerator.
Definition Ports.h:318
std::function< bool(MessageData)> callback
Backends call this callback when new data is available.
Definition Ports.h:378
Root class of the ESI type system.
Definition Types.h:34
A ChannelPort which sends data to the accelerator.
Definition Ports.h:206
virtual bool tryWriteImpl(const MessageData &data)=0
Implementation for tryWrite(). Subclasses must implement this.
virtual void writeImpl(const MessageData &)=0
Implementation for write(). Subclasses must implement this.
ServerUnaryReactor * ListChannels(CallbackServerContext *, const VoidMessage *, ListOfChannels *channelsOut) override
Load the list of channels into the response and fire it off.
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
std::unique_ptr< Server > server
ServerWriteReactor< esi::cosim::Message > * ConnectToClientChannel(CallbackServerContext *context, const ChannelDesc *request) override
When a client sends a message to a read port (write port on this end), start streaming messages until...
std::vector< uint8_t > compressedManifest
Definition RpcServer.cpp:96
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Definition RpcServer.cpp:62
Impl(Context &ctxt, int port)
Start a server on the given port. -1 means to let the OS pick a port.
ServerUnaryReactor * GetManifest(CallbackServerContext *context, const VoidMessage *, Manifest *response) override
std::map< std::string, std::unique_ptr< RpcServerWritePort > > writePorts
Definition RpcServer.cpp:98
void stop(uint32_t timeoutMS=0)
ServerUnaryReactor * SendToServer(CallbackServerContext *context, const esi::cosim::AddressedMessage *request, esi::cosim::VoidMessage *response) override
When a client sends a message to a write port (a read port on this end), simply locate the associated...
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
std::map< std::string, std::unique_ptr< RpcServerReadPort > > readPorts
Definition RpcServer.cpp:97
RpcServer(Context &ctxt)
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
Register a read or write port which communicates over RPC.
void run(int port=-1)
void stop(uint32_t timeoutMS=0)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Set the manifest and version.
std::unique_ptr< Impl > impl
Definition RpcServer.h:67
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
Thread safe queue.
Definition Utils.h:39
void push(E... t)
Push onto the queue.
Definition Utils.h:55
Definition esi.py:1