CIRCT 22.0.0git
Loading...
Searching...
No Matches
RpcServer.cpp
Go to the documentation of this file.
1//===- RpcServer.cpp - Run a cosim server ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "esi/Context.h"
11#include "esi/Utils.h"
12
13#include "cosim.grpc.pb.h"
14
15#include <grpc/grpc.h>
16#include <grpcpp/security/server_credentials.h>
17#include <grpcpp/server.h>
18#include <grpcpp/server_builder.h>
19#include <grpcpp/server_context.h>
20
21#include <algorithm>
22#include <cassert>
23#include <cstdlib>
24#include <format>
25
26using namespace esi;
27using namespace esi::cosim;
28
29using grpc::CallbackServerContext;
30using grpc::Server;
31using grpc::ServerUnaryReactor;
32using grpc::ServerWriteReactor;
33using grpc::Status;
34using grpc::StatusCode;
35
/// Write the port number to a file. Necessary when we are allowed to select our
/// own port. We can't use stdout/stderr because the flushing semantics are
/// undefined (as in `flush()` doesn't work on all simulators).
///
/// The file is named "cosim.cfg" (rather than something port-specific) since we
/// may want to include other info in the future. It is created in the current
/// working directory and truncated if it already exists.
///
/// Throws std::runtime_error if the file cannot be opened, written, or closed,
/// instead of handing a null `FILE *` to `fprintf`/`fclose`.
static void writePort(uint16_t port) {
  static constexpr const char *configPath = "cosim.cfg";

  FILE *fd = fopen(configPath, "w");
  if (!fd)
    throw std::runtime_error(std::string("Failed to open '") + configPath +
                             "' for writing");

  // Always close the handle, even if the write failed, then report either
  // failure. fclose() is where buffered-write errors typically surface.
  const bool written =
      fprintf(fd, "port: %u\n", static_cast<unsigned int>(port)) >= 0;
  const bool closed = fclose(fd) == 0;
  if (!written || !closed)
    throw std::runtime_error(std::string("Failed to write port to '") +
                             configPath + "'");
}
45
46namespace {
47class RpcServerReadPort;
48class RpcServerWritePort;
49} // namespace
50
53public:
54 Impl(Context &ctxt, int port);
55 ~Impl();
56
57 Context &getContext() { return ctxt; }
58
59 //===--------------------------------------------------------------------===//
60 // Internal API
61 //===--------------------------------------------------------------------===//
62
64 const std::vector<uint8_t> &compressedManifest) {
65 this->compressedManifest = compressedManifest;
66 this->esiVersion = esiVersion;
67 }
68
69 ReadChannelPort &registerReadPort(const std::string &name,
70 const std::string &type);
71 WriteChannelPort &registerWritePort(const std::string &name,
72 const std::string &type);
73
74 void stop(uint32_t timeoutMS = 0);
75
76 int getPort() { return port; }
77
78 //===--------------------------------------------------------------------===//
79 // RPC API implementations. See the .proto file for the API documentation.
80 //===--------------------------------------------------------------------===//
81
82 ServerUnaryReactor *GetManifest(CallbackServerContext *context,
83 const VoidMessage *,
84 Manifest *response) override;
85 ServerUnaryReactor *ListChannels(CallbackServerContext *, const VoidMessage *,
86 ListOfChannels *channelsOut) override;
87 ServerWriteReactor<esi::cosim::Message> *
88 ConnectToClientChannel(CallbackServerContext *context,
89 const ChannelDesc *request) override;
90 ServerUnaryReactor *SendToServer(CallbackServerContext *context,
91 const esi::cosim::AddressedMessage *request,
92 esi::cosim::VoidMessage *response) override;
93
94private:
97 std::vector<uint8_t> compressedManifest;
98 std::map<std::string, std::unique_ptr<RpcServerReadPort>> readPorts;
99 std::map<std::string, std::unique_ptr<RpcServerWritePort>> writePorts;
100 int port = -1;
101 std::unique_ptr<Server> server;
102};
104
105//===----------------------------------------------------------------------===//
106// Read and write ports
107//
108// Implemented as simple queues which the RPC server writes to and reads from.
109//===----------------------------------------------------------------------===//
110
111namespace {
112/// Implements a simple read queue. The RPC server will push messages into this
113/// as appropriate.
114class RpcServerReadPort : public ReadChannelPort {
115public:
116 RpcServerReadPort(Type *type) : ReadChannelPort(type) {}
117
118 /// Internal call. Push a message FROM the RPC client to the read port.
119 void push(MessageData &data) {
120 while (!callback(data))
121 std::this_thread::sleep_for(std::chrono::milliseconds(1));
122 }
123};
124
125/// Implements a simple write queue. The RPC server will pull messages from this
126/// as appropriate. Note that this could be more performant if a callback is
127/// used. This would have more complexity as when a client disconnects the
128/// outstanding messages will need somewhere to be held until the next client
129/// connects. For now, it's simpler to just have the server poll the queue.
130class RpcServerWritePort : public WriteChannelPort {
131public:
132 RpcServerWritePort(Type *type) : WriteChannelPort(type) {}
133
135
136protected:
137 void writeImpl(const MessageData &data) override { writeQueue.push(data); }
138 bool tryWriteImpl(const MessageData &data) override {
139 writeQueue.push(data);
140 return true;
141 }
142};
143} // namespace
144
145//===----------------------------------------------------------------------===//
146// RPC server implementations
147//===----------------------------------------------------------------------===//
148
149/// Start a server on the given port. -1 means to let the OS pick a port.
150Impl::Impl(Context &ctxt, int port) : ctxt(ctxt), esiVersion(-1) {
151 grpc::ServerBuilder builder;
152 std::string server_address("127.0.0.1:" + std::to_string(port));
153 // TODO: use secure credentials. Not so bad for now since we only accept
154 // connections on localhost.
155 builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(),
156 &port);
157 builder.RegisterService(this);
158 server = builder.BuildAndStart();
159 if (!server)
160 throw std::runtime_error("Failed to start server on " + server_address);
162 this->port = port;
163 ctxt.getLogger().info("cosim", "Server listening on 127.0.0.1:" +
164 std::to_string(port));
165}
166
167void Impl::stop(uint32_t timeoutMS) {
168 // Disconnect all the ports.
169 for (auto &[name, port] : readPorts)
170 port->disconnect();
171 for (auto &[name, port] : writePorts)
172 port->disconnect();
173
174 // Shutdown the server and wait for it to finish.
175 if (timeoutMS > 0)
176 server->Shutdown(gpr_time_add(
177 gpr_now(GPR_CLOCK_REALTIME),
178 gpr_time_from_millis(static_cast<int>(timeoutMS), GPR_TIMESPAN)));
179 else
180 server->Shutdown();
181
182 server->Wait();
183 server = nullptr;
184}
185
187 if (server)
188 stop();
189}
190
192 const std::string &type) {
193 auto port = new RpcServerReadPort(new Type(type));
194 readPorts.emplace(name, port);
195 return *port;
196}
198 const std::string &type) {
199 auto port = new RpcServerWritePort(new Type(type));
200 writePorts.emplace(name, port);
201 return *port;
202}
203
204ServerUnaryReactor *Impl::GetManifest(CallbackServerContext *context,
205 const VoidMessage *, Manifest *response) {
206 response->set_esi_version(esiVersion);
207 response->set_compressed_manifest(compressedManifest.data(),
208 compressedManifest.size());
209 ServerUnaryReactor *reactor = context->DefaultReactor();
210 reactor->Finish(Status::OK);
211 return reactor;
212}
213
214/// Load the list of channels into the response and fire it off.
215ServerUnaryReactor *Impl::ListChannels(CallbackServerContext *context,
216 const VoidMessage *,
217 ListOfChannels *channelsOut) {
218 for (auto &[name, port] : readPorts) {
219 auto *channel = channelsOut->add_channels();
220 channel->set_name(name);
221 channel->set_type(port->getType()->getID());
222 channel->set_dir(ChannelDesc::Direction::ChannelDesc_Direction_TO_SERVER);
223 }
224 for (auto &[name, port] : writePorts) {
225 auto *channel = channelsOut->add_channels();
226 channel->set_name(name);
227 channel->set_type(port->getType()->getID());
228 channel->set_dir(ChannelDesc::Direction::ChannelDesc_Direction_TO_CLIENT);
229 }
230
231 // The default reactor is basically to just finish the RPC call as if we're
232 // implementing the RPC function as a blocking call.
233 auto reactor = context->DefaultReactor();
234 reactor->Finish(Status::OK);
235 return reactor;
236}
237
238namespace {
239/// When a client connects to a read port (on its end, a write port on this
240/// end), construct one of these to poll the corresponding write port on this
241/// side and forward the messages.
242class RpcServerWriteReactor : public ServerWriteReactor<esi::cosim::Message> {
243public:
244 RpcServerWriteReactor(RpcServerWritePort *writePort)
245 : writePort(writePort), sentSuccessfully(SendStatus::UnknownStatus),
246 shutdown(false), onDoneCalled(false) {
247 myThread = std::thread(&RpcServerWriteReactor::threadLoop, this);
248 }
249
250 // gRPC manages the lifecycle of this object. OnDone() is called when gRPC is
251 // completely done with this reactor. We must wait for our thread to finish
252 // before deleting. See:
253 // https://github.com/grpc/grpc/blob/4795c5e69b25e8c767b498bea784da0ef8c96fd5/examples/cpp/route_guide/route_guide_callback_server.cc#L120
254 void OnDone() override {
255 // Signal shutdown and wake up any waiting threads.
256 {
257 std::scoped_lock<std::mutex> lock(sentMutex);
258 shutdown = true;
259 onDoneCalled = true;
260 }
261 sentSuccessfullyCV.notify_one();
262 onDoneCV.notify_one();
263
264 // Wait for the thread to finish before self-deleting.
265 if (myThread.joinable())
266 myThread.join();
267
268 delete this;
269 }
270
271 void OnWriteDone(bool ok) override {
272 std::scoped_lock<std::mutex> lock(sentMutex);
273 sentSuccessfully = ok ? SendStatus::Success : SendStatus::Failure;
274 sentSuccessfullyCV.notify_one();
275 }
276
277 void OnCancel() override {
278 std::scoped_lock<std::mutex> lock(sentMutex);
279 shutdown = true;
280 sentSuccessfully = SendStatus::Disconnect;
281 sentSuccessfullyCV.notify_one();
282 }
283
284private:
285 /// The polling loop.
286 void threadLoop();
287 /// The polling thread.
288 std::thread myThread;
289
290 /// Assoicated write port on this side. (Read port on the client side.)
291 RpcServerWritePort *writePort;
292
293 /// Mutex to protect the sentSuccessfully flag and shutdown state.
294 std::mutex sentMutex;
295 enum SendStatus { UnknownStatus, Success, Failure, Disconnect };
296 volatile SendStatus sentSuccessfully;
297 std::condition_variable sentSuccessfullyCV;
298
299 std::atomic<bool> shutdown;
300
301 /// Condition variable to wait for OnDone to be called.
302 bool onDoneCalled;
303 std::condition_variable onDoneCV;
304};
305
306} // namespace
307
308void RpcServerWriteReactor::threadLoop() {
309 while (!shutdown && sentSuccessfully != SendStatus::Disconnect) {
310 // TODO: adapt this to a new notification mechanism which is forthcoming.
311 if (!writePort || writePort->writeQueue.empty()) {
312 std::this_thread::sleep_for(std::chrono::microseconds(100));
313 continue;
314 }
315
316 // This lambda will get called with the message at the front of the queue.
317 // If the send is successful, return true to pop it. We don't know, however,
318 // if the message was sent successfully in this thread. It's only when the
319 // `OnWriteDone` method is called by gRPC that we know. Use locking and
320 // condition variables to orchestrate this confirmation.
321 writePort->writeQueue.pop([this](const MessageData &data) -> bool {
322 if (shutdown)
323 return false;
324
325 esi::cosim::Message msg;
326 msg.set_data(reinterpret_cast<const char *>(data.getBytes()),
327 data.getSize());
328
329 // Get a lock, reset the flag, start sending the message, and wait for the
330 // write to complete or fail. Be mindful of the shutdown flag.
331 std::unique_lock<std::mutex> lock(sentMutex);
332 sentSuccessfully = SendStatus::UnknownStatus;
333 StartWrite(&msg);
334 sentSuccessfullyCV.wait(lock, [&]() {
335 return shutdown || sentSuccessfully != SendStatus::UnknownStatus;
336 });
337 bool ret = sentSuccessfully == SendStatus::Success;
338 lock.unlock();
339 return ret;
340 });
341 }
342
343 // Call Finish to signal gRPC that we're done. gRPC will then call OnDone().
344 Finish(Status::OK);
345}
346
347/// When a client sends a message to a read port (write port on this end), start
348/// streaming messages until the client calls uncle and requests a cancellation.
349ServerWriteReactor<esi::cosim::Message> *
351 const ChannelDesc *request) {
352 getContext().getLogger().debug("cosim", "connect to client channel");
353 auto it = writePorts.find(request->name());
354 if (it == writePorts.end()) {
355 auto reactor = new RpcServerWriteReactor(nullptr);
356 reactor->Finish(Status(StatusCode::NOT_FOUND, "Unknown channel"));
357 return reactor;
358 }
359 return new RpcServerWriteReactor(it->second.get());
360}
361
362/// When a client sends a message to a write port (a read port on this end),
363/// simply locate the associated port, and write that message into its queue.
364ServerUnaryReactor *
365Impl::SendToServer(CallbackServerContext *context,
366 const esi::cosim::AddressedMessage *request,
367 esi::cosim::VoidMessage *response) {
368 auto reactor = context->DefaultReactor();
369 auto it = readPorts.find(request->channel_name());
370 if (it == readPorts.end()) {
371 reactor->Finish(Status(StatusCode::NOT_FOUND, "Unknown channel"));
372 return reactor;
373 }
374
375 std::string msgDataString = request->message().data();
376 MessageData data(reinterpret_cast<const uint8_t *>(msgDataString.data()),
377 msgDataString.size());
378 try {
380 "cosim",
381 std::format("Channel '{}': Received message; pushing data to read port",
382 request->channel_name()));
383 it->second->push(data);
384 } catch (const std::exception &e) {
386 "cosim",
387 std::format("Channel '{}': Error pushing message to read port: {}",
388 request->channel_name(), e.what()));
389 reactor->Finish(
390 Status(StatusCode::INTERNAL, "Error pushing message to port"));
391 return reactor;
392 }
393
394 reactor->Finish(Status::OK);
395 return reactor;
396}
397
398//===----------------------------------------------------------------------===//
399// RpcServer pass throughs to the actual implementations above.
400//===----------------------------------------------------------------------===//
402RpcServer::~RpcServer() = default;
403
404void RpcServer::setManifest(int esiVersion,
405 const std::vector<uint8_t> &compressedManifest) {
406 if (!impl)
407 throw std::runtime_error("Server not running");
408
409 impl->setManifest(esiVersion, compressedManifest);
410}
411
413 const std::string &type) {
414 if (!impl)
415 throw std::runtime_error("Server not running");
416 return impl->registerReadPort(name, type);
417}
418
420 const std::string &type) {
421 return impl->registerWritePort(name, type);
422}
423void RpcServer::run(int port) {
424 if (impl)
425 throw std::runtime_error("Server already running");
426 impl = std::make_unique<Impl>(ctxt, port);
427}
428void RpcServer::stop(uint32_t timeoutMS) {
429 if (!impl)
430 throw std::runtime_error("Server not running");
431 impl->stop(timeoutMS);
432}
433
435 if (!impl)
436 throw std::runtime_error("Server not running");
437 return impl->getPort();
438}
static std::unique_ptr< Context > context
static void writePort(uint16_t port)
Write the port number to a file.
Definition RpcServer.cpp:39
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:64
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:75
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:83
Class to parse a manifest.
Definition Manifest.h:39
A logical chunk of data representing serialized data.
Definition Common.h:113
A ChannelPort which reads data from the accelerator.
Definition Ports.h:318
std::function< bool(MessageData)> callback
Backends call this callback when new data is available.
Definition Ports.h:378
Root class of the ESI type system.
Definition Types.h:34
A ChannelPort which sends data to the accelerator.
Definition Ports.h:206
virtual bool tryWriteImpl(const MessageData &data)=0
Implementation for tryWrite(). Subclasses must implement this.
virtual void writeImpl(const MessageData &)=0
Implementation for write(). Subclasses must implement this.
ServerUnaryReactor * ListChannels(CallbackServerContext *, const VoidMessage *, ListOfChannels *channelsOut) override
Load the list of channels into the response and fire it off.
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
std::unique_ptr< Server > server
ServerWriteReactor< esi::cosim::Message > * ConnectToClientChannel(CallbackServerContext *context, const ChannelDesc *request) override
When a client sends a message to a read port (write port on this end), start streaming messages until...
std::vector< uint8_t > compressedManifest
Definition RpcServer.cpp:97
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Definition RpcServer.cpp:63
Impl(Context &ctxt, int port)
Start a server on the given port. -1 means to let the OS pick a port.
ServerUnaryReactor * GetManifest(CallbackServerContext *context, const VoidMessage *, Manifest *response) override
std::map< std::string, std::unique_ptr< RpcServerWritePort > > writePorts
Definition RpcServer.cpp:99
void stop(uint32_t timeoutMS=0)
ServerUnaryReactor * SendToServer(CallbackServerContext *context, const esi::cosim::AddressedMessage *request, esi::cosim::VoidMessage *response) override
When a client sends a message to a write port (a read port on this end), simply locate the associated...
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
std::map< std::string, std::unique_ptr< RpcServerReadPort > > readPorts
Definition RpcServer.cpp:98
RpcServer(Context &ctxt)
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
Register a read or write port which communicates over RPC.
void run(int port=-1)
void stop(uint32_t timeoutMS=0)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Set the manifest and version.
std::unique_ptr< Impl > impl
Definition RpcServer.h:67
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
Thread safe queue.
Definition Utils.h:39
void push(E... t)
Push onto the queue.
Definition Utils.h:55
Definition esi.py:1