CIRCT 23.0.0git
Loading...
Searching...
No Matches
RpcServer.cpp
Go to the documentation of this file.
1//===- RpcServer.cpp - Run a cosim server ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation of the cosim RPC server over the WebSocket + JSON protocol.
10// The wire protocol is documented in cosim-protocol.md.
11//
12//===----------------------------------------------------------------------===//
13
15#include "esi/Context.h"
16#include "esi/Utils.h"
17#include "esi/backends/RpcClient.h" // for ChannelDirection
18
19#include "RpcWire.h"
20
21#include <ixwebsocket/IXBase64.h>
22#include <ixwebsocket/IXNetSystem.h>
23#include <ixwebsocket/IXWebSocket.h>
24#include <ixwebsocket/IXWebSocketServer.h>
25#include <nlohmann/json.hpp>
26
27#include <atomic>
28#include <cassert>
29#include <condition_variable>
30#include <cstdint>
31#include <cstdio>
32#include <cstring>
33#include <format>
34#include <map>
35#include <mutex>
36#include <optional>
37#include <string>
38#include <thread>
39#include <unordered_map>
40#include <unordered_set>
41#include <vector>
42
43#ifdef _WIN32
44#include <winsock2.h>
45#include <ws2tcpip.h>
46#else
47#include <arpa/inet.h>
48#include <netinet/in.h>
49#include <sys/socket.h>
50#include <unistd.h>
51#endif
52
53using namespace esi;
54using namespace esi::cosim;
55using json = nlohmann::json;
56
57namespace {
58
59//===----------------------------------------------------------------------===//
60// Helpers
61//===----------------------------------------------------------------------===//
62
/// Publish the listening port as a single `port: <n>` line in `cosim.cfg`
/// (relative to the current working directory). Callers such as esi-cosim
/// read this file to discover the port, which matters when the OS picked it,
/// because the simulator's stdout/stderr buffering is undefined.
///
/// Best effort: if the file cannot be opened, nothing is written and no error
/// is reported.
static void writePortFile(uint16_t port) {
  if (FILE *cfg = fopen("cosim.cfg", "w")) {
    const unsigned int portNum = port;
    fprintf(cfg, "port: %u\n", portNum);
    fclose(cfg);
  }
}
73
/// Ask the OS for a free TCP port by binding a throwaway IPv4 socket to
/// 127.0.0.1:0 and reading back the port it was assigned.
///
/// IXWebSocket exposes the port a user passed in, but it does not query
/// `getsockname` after binding. For the "let the OS pick" path we therefore
/// find a free port ourselves and hand it to IX.
///
/// The probe socket is closed before returning, so the port is only free, not
/// reserved. SO_REUSEADDR is set on the probe and it is closed immediately to
/// keep that window small.
///
/// \returns the assigned port in host byte order, or -1 if creating, binding,
///          or querying the probe socket fails.
static int pickEphemeralPort() {
#ifdef _WIN32
  SOCKET probe = socket(AF_INET, SOCK_STREAM, 0);
  if (probe == INVALID_SOCKET)
    return -1;
  auto closeProbe = [probe] { closesocket(probe); };
#else
  int probe = socket(AF_INET, SOCK_STREAM, 0);
  if (probe < 0)
    return -1;
  auto closeProbe = [probe] { close(probe); };
#endif

  // Failure to set SO_REUSEADDR is deliberately ignored; it is only a hint.
  const int reuse = 1;
  setsockopt(probe, SOL_SOCKET, SO_REUSEADDR,
             reinterpret_cast<const char *>(&reuse), sizeof(reuse));

  sockaddr_in loopback{};
  loopback.sin_family = AF_INET;
  loopback.sin_port = 0; // Port 0: let the kernel choose.
  loopback.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  if (bind(probe, reinterpret_cast<sockaddr *>(&loopback), sizeof(loopback)) <
      0) {
    closeProbe();
    return -1;
  }

  sockaddr_in assigned{};
  socklen_t assignedLen = sizeof(assigned);
  const int nameRc =
      getsockname(probe, reinterpret_cast<sockaddr *>(&assigned), &assignedLen);
  closeProbe();
  if (nameRc < 0)
    return -1;
  return ntohs(assigned.sin_port);
}
116
117class RpcServerReadPort;
118class RpcServerWritePort;
119
120} // namespace
121
122//===----------------------------------------------------------------------===//
123// RpcServer::Impl - private implementation
124//===----------------------------------------------------------------------===//
125
127public:
128 Impl(Context &ctxt, int port);
129 ~Impl();
130
131 Context &getContext() { return ctxt; }
132
133 void setManifest(int esiVersion,
134 const std::vector<uint8_t> &compressedManifest);
135
136 ReadChannelPort &registerReadPort(const std::string &name,
137 const std::string &type);
138 WriteChannelPort &registerWritePort(const std::string &name,
139 const std::string &type);
140
141 void stop(uint32_t timeoutMS);
142 int getPort() const { return boundPort; }
143
144 // Dirty-set doorbell for the transport thread. Public so each port's
145 // TSQueue can pass `[&impl, id]{ impl.readyIds.markReady(id); }` as its
146 // push-notifier; also called directly from `handleSubscribe` to flush
147 // anything queued before the subscription arrived.
149
150private:
151 // Reuse the public direction enum from the client side rather than
152 // duplicating it; both ends of the cosim transport agree on the same two
153 // values.
155 struct ChannelInfo {
156 uint64_t id;
157 std::string name;
158 std::string typeId;
160 RpcServerReadPort *readPort = nullptr;
161 RpcServerWritePort *writePort = nullptr;
162 };
163
165 ix::WebSocket *ws;
166 // The set of `to_client` channel ids the client subscribed to.
167 std::mutex subscribedMutex;
168 std::unordered_set<uint64_t> subscribed;
169 // True once `hello` has been answered.
170 bool helloDone = false;
171 };
172
174
175 // Manifest state. setManifest() flips manifestReady and broadcasts the CV;
176 // any in-flight `hello` handler blocks on this until it is set.
177 std::mutex manifestMutex;
178 std::condition_variable manifestReadyCV;
179 int esiVersion = -1;
180 std::vector<uint8_t> compressedManifest;
181 bool manifestReady = false;
182
183 // Channel table is keyed by name in the public API and by id on the wire.
184 // Ports are owned here; ChannelInfo holds non-owning pointers.
185 std::mutex channelsMutex;
186 std::map<std::string, std::unique_ptr<RpcServerReadPort>> readPorts;
187 std::map<std::string, std::unique_ptr<RpcServerWritePort>> writePorts;
188 std::unordered_map<uint64_t, ChannelInfo> channelById;
189 std::unordered_map<std::string, uint64_t> idByName;
190 uint64_t nextChannelId = 0;
191
192 // Session state. v3 of the protocol allows a single concurrent client.
193 std::mutex sessionMutex;
194 std::unique_ptr<ClientSession> session;
195
196 // Transport thread; drains `readyIds` and dispatches per channel id.
197 std::thread transportThread;
198
199 // The IXWebSocket server.
200 std::unique_ptr<ix::WebSocketServer> server;
201 int boundPort = -1;
202
203 // Connection callbacks.
204 void onClientMessage(std::shared_ptr<ix::ConnectionState> state,
205 ix::WebSocket &ws, const ix::WebSocketMessagePtr &msg);
206 void onOpen(ix::WebSocket &ws);
207 void onClose();
208 void handleBinaryFrame(const std::string &data);
209 void handleControlFrame(ix::WebSocket &ws, const std::string &text);
210 void handleHello(ix::WebSocket &ws, uint64_t requestId, const json &params);
211 void handleSubscribe(ix::WebSocket &ws, uint64_t requestId,
212 const json &params);
213 void handleUnsubscribe(ix::WebSocket &ws, uint64_t requestId,
214 const json &params);
215 void sendResult(ix::WebSocket &ws, uint64_t requestId, const json &result);
216 void sendError(ix::WebSocket &ws, uint64_t requestId, const std::string &code,
217 const std::string &message);
218
219 void transportLoop();
220
221 // Cross-thread fault propagation out of the IX network thread. See
222 // FaultStash docs in RpcWire.h.
224};
225
227
228//===----------------------------------------------------------------------===//
229// Port implementations
230//
231// Read and write ports are simple queues; the RPC server pushes/pops as
232// appropriate. These mirror the previous gRPC implementation: they are
233// transport-agnostic except that write port writes ring the sender doorbell.
234//===----------------------------------------------------------------------===//
235
236namespace {
237/// Read port for "to server" channels.
238///
239/// The cosim transport hands inbound frames synchronously from the IX
240/// network thread, which is shared across every channel on a connection.
241/// Any back-pressure on a per-port consumer would stall that thread and risk
242/// a cross-channel deadlock when an accelerator's flow control requires
243/// ordering across channels. So we *force* polling-mode `connect` to use an
244/// unbounded internal queue regardless of what the caller passes --
245/// `ReadChannelPort::pollingState` then acts as our buffer and
246/// `invokeCallback` is guaranteed non-blocking. The unbounded queue mirrors
247/// the existing `to_client` write queue and is acceptable for a cosim
248/// driver.
249class RpcServerReadPort : public ReadChannelPort {
250public:
252
253 /// Polling-mode connect: force the internal queue to be unbounded
254 /// (`bufferSize == 0`) regardless of what the caller asked for.
255 void connect(const ConnectOptions &options = {}) override {
256 ConnectOptions forced = options;
257 forced.bufferSize = 0;
259 }
260
261 /// Hand one inbound frame to the user callback. Returns `false` only when
262 /// the port is disconnected (typically during shutdown), in which case the
263 /// caller should drop the frame.
264 bool deliver(const MessageData &data) {
265 std::unique_ptr<SegmentedMessageData> msg =
266 std::make_unique<MessageData>(data);
267 return invokeCallback(msg);
268 }
269};
270
271/// Write queue for "to client" channels. Writes go into the per-port TSQueue;
272/// the queue's notifier hook rings the transport doorbell so the transport
273/// thread wakes up and drains across the WebSocket. DPI threads must never
274/// block on I/O, so this path is strictly non-blocking.
275class RpcServerWritePort : public WriteChannelPort {
276public:
277 RpcServerWritePort(Type *type, Impl &impl, uint64_t channelId)
278 : WriteChannelPort(type), channelId(channelId),
279 writeQueue([&impl, channelId] { impl.readyIds.markReady(channelId); }) {
280 }
281
282 uint64_t getChannelId() const { return channelId; }
283 uint64_t channelId;
285
286protected:
287 // TODO: TSQueue is unbounded so if there's no client subscribed it'll fill up
288 // memory. We should add some backpressure mechanism here to avoid that.
289 void writeImpl(const MessageData &data) override { writeQueue.push(data); }
290 bool tryWriteImpl(const MessageData &data) override {
291 writeImpl(data);
292 return true;
293 }
294};
295} // namespace
296
297//===----------------------------------------------------------------------===//
298// Impl - server lifecycle
299//===----------------------------------------------------------------------===//
300
/// Bring up the cosim WebSocket server on 127.0.0.1.
///
/// \param port TCP port to listen on. A value <= 0 asks the OS for an
///             ephemeral port. Values above 65535 are rejected.
///
/// The chosen port is written to `cosim.cfg` and logged. The IX server and the
/// transport thread that drains `to_client` queues are both started before the
/// constructor returns.
///
/// \throws std::runtime_error if network initialization fails, no ephemeral
///         port can be obtained, the port is out of range, or `listen` fails.
Impl::Impl(Context &ctxt, int port) : ctxt(ctxt) {
  // On Windows, `ix::initNetSystem()` calls `WSAStartup` and returns false if
  // that fails. On other platforms it's a no-op that always returns true.
  if (!ix::initNetSystem())
    throw std::runtime_error(
        "RpcServer: ix::initNetSystem() failed (WSAStartup)");

  // Resolve port 0 / negative request to an OS-assigned ephemeral port,
  // since IXWebSocket does not expose the bound port after the fact.
  int requestedPort = port;
  if (requestedPort <= 0) {
    requestedPort = pickEphemeralPort();
    if (requestedPort <= 0)
      throw std::runtime_error(
          "RpcServer: failed to obtain an ephemeral TCP port");
  } else if (requestedPort > 65535) {
    // TCP ports are 16-bit. An out-of-range value would be truncated when
    // written to cosim.cfg below, while the log line reports the untruncated
    // number, so callers would be told about a port we are not on. Reject it.
    throw std::runtime_error(std::format(
        "RpcServer: requested port {} is out of range (1-65535)", port));
  }

  const std::string host = "127.0.0.1";
  server = std::make_unique<ix::WebSocketServer>(requestedPort, host);
  server->disablePerMessageDeflate();

  server->setOnClientMessageCallback(
      [this](std::shared_ptr<ix::ConnectionState> state, ix::WebSocket &ws,
             const ix::WebSocketMessagePtr &msg) {
        onClientMessage(std::move(state), ws, msg);
      });

  auto res = server->listen();
  if (!res.first)
    throw std::runtime_error("RpcServer: listen failed: " + res.second);

  server->start();
  boundPort = requestedPort;
  writePortFile(static_cast<uint16_t>(boundPort));
  ctxt.getLogger().info("cosim", std::format("Server listening on {}:{}", host,
                                             static_cast<unsigned>(boundPort)));

  // Start the always-on transport thread that drains every port's queue. It
  // lives the entire lifetime of the server and sleeps on `readyIds`'s
  // internal CV when there is nothing to do regardless of whether a client
  // is connected.
  transportThread = std::thread([this] { transportLoop(); });
}
344
346 if (server) {
347 // A pending fault from the network thread is interesting, but throwing
348 // out of a destructor is worse than swallowing it. The user can still
349 // observe the fault by calling stop() explicitly before letting the
350 // server destruct; if they didn't, log and continue.
351 try {
352 stop(0);
353 } catch (const std::exception &e) {
355 "cosim",
356 std::string("Suppressed exception during ~RpcServer::Impl: ") +
357 e.what());
358 } catch (...) {
360 "cosim", "Suppressed non-std::exception during ~RpcServer::Impl");
361 }
362 }
363}
364
365void Impl::setManifest(int esiVersion,
366 const std::vector<uint8_t> &compressedManifest) {
368 {
369 std::lock_guard<std::mutex> lock(manifestMutex);
370 this->esiVersion = esiVersion;
371 this->compressedManifest = compressedManifest;
372 manifestReady = true;
373 }
374 manifestReadyCV.notify_all();
375}
376
378 const std::string &type) {
380 std::lock_guard<std::mutex> lock(channelsMutex);
381 uint64_t id = nextChannelId++;
382 auto port = std::make_unique<RpcServerReadPort>(new Type(type));
383 RpcServerReadPort *raw = port.get();
384 readPorts.emplace(name, std::move(port));
385
386 ChannelInfo info{id, name, type, ChannelDirection::ToServer, raw, nullptr};
387 channelById.emplace(id, std::move(info));
388 idByName.emplace(name, id);
389 return *raw;
390}
391
393 const std::string &type) {
395 std::lock_guard<std::mutex> lock(channelsMutex);
396 uint64_t id = nextChannelId++;
397 auto port = std::make_unique<RpcServerWritePort>(new Type(type), *this, id);
398 RpcServerWritePort *raw = port.get();
399 writePorts.emplace(name, std::move(port));
400
401 ChannelInfo info{id, name, type, ChannelDirection::ToClient, nullptr, raw};
402 channelById.emplace(id, std::move(info));
403 idByName.emplace(name, id);
404 return *raw;
405}
406
407void Impl::stop(uint32_t /*timeoutMS*/) {
408 // Disconnect ports first so any in-flight DPI writes see a closed channel.
409 {
410 std::lock_guard<std::mutex> lock(channelsMutex);
411 for (auto &[name, port] : readPorts)
412 port->disconnect();
413 for (auto &[name, port] : writePorts)
414 port->disconnect();
415 }
416
417 // Retire the transport thread.
418 if (transportThread.joinable()) {
420 transportThread.join();
421 // NB: there's no explicit "clear readyIds" step: each port's queue
422 // contents and the corresponding dirty markers belong to the ports,
423 // not to the transport thread, and the server doesn't restart in-place.
424 }
425
426 if (server) {
427 server->stop();
428 server.reset();
429 }
430
431 {
432 std::lock_guard<std::mutex> lock(sessionMutex);
433 session.reset();
434 }
435
436 // Surface any fault the IX network thread caught while we were running.
437 // Done last so the rest of the shutdown sequence completes regardless.
439}
440
441//===----------------------------------------------------------------------===//
442// Impl - WebSocket message dispatch
443//===----------------------------------------------------------------------===//
444
445void Impl::onClientMessage(std::shared_ptr<ix::ConnectionState> /*state*/,
446 ix::WebSocket &ws,
447 const ix::WebSocketMessagePtr &msg) {
448 switch (msg->type) {
449 case ix::WebSocketMessageType::Open: {
450 // Single-client model: if a session is already active, send an unsolicited
451 // JSON error frame so the new client gets an actionable, application-level
452 // reason, then close with 1013 ("Try Again Later") rather than 1011
453 // ("Internal Error") -- the latter falsely implies a server-side bug.
454 std::lock_guard<std::mutex> lock(sessionMutex);
455 if (session) {
457 "cosim", "Rejecting additional client; one already connected");
458 json err;
459 err["type"] = "error";
460 err["error"] = {{"code", "server_busy"},
461 {"message", "cosim server allows only one client at a "
462 "time; another client is already connected"}};
463 ws.sendUtf8Text(err.dump());
464 // 1013 = Try Again Later (RFC 6455 §7.4).
465 ws.close(1013, "cosim server busy: another client is already connected");
466 return;
467 }
468 session = std::make_unique<ClientSession>();
469 session->ws = &ws;
470 ctxt.getLogger().debug("cosim", "Client connected");
471 return;
472 }
473 case ix::WebSocketMessageType::Close:
474 case ix::WebSocketMessageType::Error: {
476 "cosim", std::format("Client disconnected: {}",
477 msg->type == ix::WebSocketMessageType::Error
478 ? msg->errorInfo.reason
479 : msg->closeInfo.reason));
480 onClose();
481 return;
482 }
483 case ix::WebSocketMessageType::Message:
484 // An exception escaping this callback would kill IX's network thread.
485 // Stash the first one so the next public RpcServer method rethrows it
486 // on the user's thread.
487 try {
488 if (msg->binary)
489 handleBinaryFrame(msg->str);
490 else
491 handleControlFrame(ws, msg->str);
492 } catch (...) {
493 faultStash.record(std::current_exception());
494 }
495 return;
496 default:
497 return;
498 }
499}
500
502 std::lock_guard<std::mutex> lock(sessionMutex);
503 session.reset();
504}
505
506void Impl::handleBinaryFrame(const std::string &data) {
507 uint64_t channelId;
508 const uint8_t *payloadBytes;
509 size_t payloadSize;
510 if (!parseDataFrame(data, channelId, payloadBytes, payloadSize)) {
511 ctxt.getLogger().error("cosim",
512 "Received binary frame shorter than 8-byte header");
513 return;
514 }
515
516 RpcServerReadPort *port = nullptr;
517 {
518 std::lock_guard<std::mutex> lock(channelsMutex);
519 auto it = channelById.find(channelId);
520 if (it == channelById.end() ||
521 it->second.direction != ChannelDirection::ToServer)
522 port = nullptr;
523 else
524 port = it->second.readPort;
525 }
526 if (!port) {
528 "cosim", std::format("Binary frame for unknown to-server channel id {}",
529 channelId));
530 return;
531 }
532
533 MessageData payload(payloadBytes, payloadSize);
534 // RpcServerReadPort always has an unbounded internal queue so `deliver` is
535 // non-blocking: it enqueues into `ReadChannelPort`'s internal polling buffer
536 // and returns. A `false` return means the port has been disconnected
537 // (shutdown); just drop the frame in that case.
538 if (!port->deliver(payload))
540 "cosim",
541 std::format("Dropped {} bytes for channel id {}: port not connected",
542 payload.getSize(), channelId));
543}
544
545void Impl::handleControlFrame(ix::WebSocket &ws, const std::string &text) {
546 json req;
547 try {
548 req = json::parse(text);
549 } catch (const std::exception &e) {
551 "cosim", std::format("Failed to parse control frame: {}", e.what()));
552 sendError(ws, 0, "protocol_error",
553 std::string("Failed to parse JSON: ") + e.what());
554 faultStash.record(std::current_exception());
555 return;
556 }
557
558 auto typeIt = req.find("type");
559 if (typeIt == req.end() || !typeIt->is_string() ||
560 typeIt->get<std::string>() != "request") {
561 sendError(ws, 0, "protocol_error",
562 "Control frame missing \"type\":\"request\"");
563 return;
564 }
565 uint64_t requestId = 0;
566 if (auto idIt = req.find("request_id"); idIt != req.end()) {
567 // Accept only unsigned integers.
568 if (!idIt->is_number_unsigned()) {
569 sendError(ws, 0, "protocol_error",
570 "\"request_id\" must be an unsigned integer");
571 return;
572 }
573 try {
574 requestId = idIt->get<uint64_t>();
575 } catch (const std::exception &e) {
576 sendError(ws, 0, "protocol_error",
577 std::string("Invalid \"request_id\": ") + e.what());
578 return;
579 }
580 }
581 auto methodIt = req.find("method");
582 if (methodIt == req.end() || !methodIt->is_string()) {
583 sendError(ws, requestId, "protocol_error", "Missing \"method\"");
584 return;
585 }
586 std::string method = methodIt->get<std::string>();
587 json params = req.value("params", json::object());
588
589 if (method == "hello")
590 handleHello(ws, requestId, params);
591 else if (method == "subscribe")
592 handleSubscribe(ws, requestId, params);
593 else if (method == "unsubscribe")
594 handleUnsubscribe(ws, requestId, params);
595 else
596 sendError(ws, requestId, "protocol_error", "Unknown method: " + method);
597}
598
599//===----------------------------------------------------------------------===//
600// Impl - control methods
601//===----------------------------------------------------------------------===//
602
603void Impl::handleHello(ix::WebSocket &ws, uint64_t requestId,
604 const json & /*params*/) {
605 // Block until the manifest has been set. This replaces the gRPC-era poll
606 // loop on the client side.
607 {
608 std::unique_lock<std::mutex> lock(manifestMutex);
609 manifestReadyCV.wait(lock, [&] { return manifestReady; });
610 }
611
612 json result;
613 result["protocol_version"] = 3;
614 {
615 std::lock_guard<std::mutex> lock(manifestMutex);
616 result["esi_version"] = esiVersion;
617 result["compressed_manifest_b64"] = macaron::Base64::Encode(
618 std::string(reinterpret_cast<const char *>(compressedManifest.data()),
619 compressedManifest.size()));
620 }
621
622 json channelsJson = json::array();
623 {
624 std::lock_guard<std::mutex> lock(channelsMutex);
625 // Emit in id order for deterministic output.
626 for (uint64_t i = 0; i < nextChannelId; ++i) {
627 auto it = channelById.find(i);
628 if (it == channelById.end())
629 continue;
630 const ChannelInfo &info = it->second;
631 json c;
632 c["channel_id"] = info.id;
633 c["name"] = info.name;
634 c["type"] = info.typeId;
635 c["direction"] = info.direction == ChannelDirection::ToServer
636 ? "to_server"
637 : "to_client";
638 channelsJson.push_back(std::move(c));
639 }
640 }
641 result["channels"] = std::move(channelsJson);
642
643 {
644 std::lock_guard<std::mutex> lock(sessionMutex);
645 if (session)
646 session->helloDone = true;
647 }
648
649 sendResult(ws, requestId, result);
650}
651
/// Handle a `subscribe` request: start streaming the given `to_client`
/// channel to the requesting client.
///
/// Error replies:
///   - `protocol_error` if `channel_id` is missing or not an unsigned integer.
///   - `unknown_channel` if no channel has that id.
///   - `wrong_direction` if the channel is not `to_client`.
///   - `internal` if there is no active session, or it belongs to a different
///     connection.
///
/// On success the ack is sent first, then the transport thread is kicked so
/// any data queued before the subscription is flushed.
void Impl::handleSubscribe(ix::WebSocket &ws, uint64_t requestId,
                           const json &params) {
  auto chIdIt = params.find("channel_id");
  if (chIdIt == params.end() || !chIdIt->is_number_unsigned()) {
    sendError(ws, requestId, "protocol_error",
              "subscribe requires unsigned \"channel_id\"");
    return;
  }
  uint64_t channelId = chIdIt->get<uint64_t>();

  {
    // Acquisition order here: channelsMutex -> sessionMutex -> subscribedMutex.
    std::lock_guard<std::mutex> chLock(channelsMutex);
    auto it = channelById.find(channelId);
    if (it == channelById.end()) {
      sendError(ws, requestId, "unknown_channel",
                std::format("No channel with id {}", channelId));
      return;
    }
    if (it->second.direction != ChannelDirection::ToClient) {
      sendError(
          ws, requestId, "wrong_direction",
          std::format("Channel id {} is not a to-client channel", channelId));
      return;
    }

    std::lock_guard<std::mutex> sLock(sessionMutex);
    // `session` belongs to the client that was accepted. A connection that was
    // rejected as busy may still deliver frames before its close completes.
    // Those must not add subscriptions to the accepted client's session.
    if (!session || session->ws != &ws) {
      sendError(ws, requestId, "internal", "No active session");
      return;
    }
    std::lock_guard<std::mutex> subLock(session->subscribedMutex);
    session->subscribed.insert(channelId);
  }

  // Send the subscribe-ack BEFORE waking the transport thread. IXWebSocket
  // queues sends in FIFO order on a given connection, so as long as the ack
  // is enqueued first, any data frames the transport thread emits next will
  // arrive after it. This spares clients from having to tolerate data on a
  // channel before they've seen the ack confirming the subscription.
  sendResult(ws, requestId, json::object());

  // Now kick the transport thread: if the port already has queued data
  // (typical for accelerator-startup writes that landed before the client
  // subscribed), this is what flushes it; if the queue is empty, the
  // transport thread will just see nothing to drain and go back to sleep.
  // The dirty-set semantics of `readyIds` dedupe any concurrent doorbell
  // from a DPI write.
  readyIds.markReady(channelId);
}
701
/// Handle an `unsubscribe` request: stop streaming the given channel to the
/// requesting client.
///
/// Error replies:
///   - `protocol_error` if `channel_id` is missing or not an unsigned integer.
///   - `internal` if there is no active session, or it belongs to a different
///     connection.
///   - `not_subscribed` if the channel was not subscribed.
///
/// The ack is sent while `sessionMutex` and `subscribedMutex` are still held.
/// The transport loop re-checks the subscription under those same locks before
/// each send, so no data frame for this channel can be sent after the ack.
void Impl::handleUnsubscribe(ix::WebSocket &ws, uint64_t requestId,
                             const json &params) {
  auto chIdIt = params.find("channel_id");
  if (chIdIt == params.end() || !chIdIt->is_number_unsigned()) {
    sendError(ws, requestId, "protocol_error",
              "unsubscribe requires unsigned \"channel_id\"");
    return;
  }
  uint64_t channelId = chIdIt->get<uint64_t>();

  std::lock_guard<std::mutex> sLock(sessionMutex);
  // Only the connection that owns the active session may change its
  // subscriptions. A connection rejected as busy must not affect it.
  if (!session || session->ws != &ws) {
    sendError(ws, requestId, "internal", "No active session");
    return;
  }
  std::lock_guard<std::mutex> subLock(session->subscribedMutex);
  auto removed = session->subscribed.erase(channelId);
  if (!removed) {
    sendError(ws, requestId, "not_subscribed",
              std::format("Channel id {} is not subscribed", channelId));
    return;
  }
  sendResult(ws, requestId, json::object());
}
726
727void Impl::sendResult(ix::WebSocket &ws, uint64_t requestId,
728 const json &result) {
729 json resp;
730 resp["type"] = "response";
731 resp["request_id"] = requestId;
732 resp["result"] = result;
733 // IXWebSocket serializes concurrent send calls internally; we don't need
734 // sessionMutex here, and taking it would deadlock with handlers that hold
735 // sessionMutex while replying (e.g. handleUnsubscribe).
736 ws.sendUtf8Text(resp.dump());
737}
738
739void Impl::sendError(ix::WebSocket &ws, uint64_t requestId,
740 const std::string &code, const std::string &message) {
741 json resp;
742 resp["type"] = "response";
743 resp["request_id"] = requestId;
744 resp["error"] = {{"code", code}, {"message", message}};
745 ws.sendUtf8Text(resp.dump());
746}
747
748//===----------------------------------------------------------------------===//
749// Impl - transport thread (push model, `to_client` only)
750//===----------------------------------------------------------------------===//
751
753 while (true) {
754 std::unordered_set<uint64_t> ids;
755 if (!readyIds.waitDrain(ids))
756 return;
757 for (uint64_t id : ids) {
758 RpcServerWritePort *writePort = nullptr;
759 {
760 std::lock_guard<std::mutex> chLock(channelsMutex);
761 auto it = channelById.find(id);
762 if (it == channelById.end() ||
763 it->second.direction != ChannelDirection::ToClient)
764 continue;
765 writePort = it->second.writePort;
766 }
767 if (!writePort)
768 continue;
769
770 // Only drain if a client is subscribed; otherwise data stays in the
771 // queue until subscription (handleSubscribe fires markReady). We
772 // re-acquire `sessionMutex` + `subscribedMutex` on every iteration so
773 // `handleUnsubscribe` can interpose between frames: once it erases the
774 // subscription and sends the unsubscribe-ack, the next iteration here
775 // sees the channel as not subscribed and breaks, guaranteeing the ack
776 // is the last `to_client` frame on that channel (as the protocol spec
777 // requires). Holding the locks across the actual `sendBinary` also
778 // keeps `session->ws` valid against a concurrent `onClose()`.
779 while (true) {
780 if (readyIds.isShutdown())
781 return;
782 std::lock_guard<std::mutex> sLock(sessionMutex);
783 if (!session)
784 break;
785 std::lock_guard<std::mutex> subLock(session->subscribedMutex);
786 if (!session->subscribed.count(id))
787 break;
788 std::optional<MessageData> msg = writePort->writeQueue.pop();
789 if (!msg)
790 break;
791 std::string frame = buildDataFrame(id, msg->getBytes(), msg->getSize());
792 // IXWebSocket serializes concurrent send calls internally and queues
793 // them in FIFO order on the WS; no extra mutex is needed for that,
794 // but we keep `sessionMutex` to guard the `ws` pointer's lifetime.
795 session->ws->sendBinary(frame);
796 }
797 }
798 }
799}
800
801//===----------------------------------------------------------------------===//
802// RpcServer pass-throughs
803//===----------------------------------------------------------------------===//
804
806RpcServer::~RpcServer() = default;
807
808void RpcServer::setManifest(int esiVersion,
809 const std::vector<uint8_t> &compressedManifest) {
810 if (!impl)
811 throw std::runtime_error("Server not running");
812 impl->setManifest(esiVersion, compressedManifest);
813}
814
816 const std::string &type) {
817 if (!impl)
818 throw std::runtime_error("Server not running");
819 return impl->registerReadPort(name, type);
820}
821
823 const std::string &type) {
824 if (!impl)
825 throw std::runtime_error("Server not running");
826 return impl->registerWritePort(name, type);
827}
828
829void RpcServer::run(int port) {
830 if (impl)
831 throw std::runtime_error("Server already running");
832 impl = std::make_unique<Impl>(ctxt, port);
833}
834
835void RpcServer::stop(uint32_t timeoutMS) {
836 if (!impl)
837 throw std::runtime_error("Server not running");
838 impl->stop(timeoutMS);
839}
840
842 if (!impl)
843 throw std::runtime_error("Server not running");
844 return impl->getPort();
845}
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:64
virtual void warning(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a warning.
Definition Logging.h:70
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:75
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:83
A concrete flat message backed by a single vector of bytes.
Definition Common.h:155
size_t getSize() const
Get the size of the data in bytes.
Definition Common.h:180
A ChannelPort which reads data from the accelerator.
Definition Ports.h:453
virtual void connect(ReadCallback callback, const ConnectOptions &options={})
Definition Ports.cpp:140
bool invokeCallback(std::unique_ptr< SegmentedMessageData > &msg)
Invoke the currently registered callback.
Definition Ports.cpp:87
ReadChannelPort(const Type *type)
Definition Ports.h:468
Root class of the ESI type system.
Definition Types.h:36
A ChannelPort which sends data to the accelerator.
Definition Ports.h:308
virtual bool tryWriteImpl(const MessageData &data)=0
Implementation for tryWrite(). Subclasses must implement this.
virtual void writeImpl(const MessageData &)=0
Implementation for write(). Subclasses must implement this.
ChannelDirection
Channel direction as reported by the server.
Definition RpcClient.h:50
Cross-thread error channel for the IXWebSocket network thread.
Definition RpcWire.h:69
void record(std::exception_ptr ep) noexcept
Stash the first exception we see (subsequent ones are dropped – the first is the most informative).
Definition RpcWire.h:73
void check()
Consume and rethrow the stored fault on the caller's thread, if any.
Definition RpcWire.h:80
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
std::unique_ptr< ix::WebSocketServer > server
std::vector< uint8_t > compressedManifest
void handleUnsubscribe(ix::WebSocket &ws, uint64_t requestId, const json &params)
std::unordered_map< std::string, uint64_t > idByName
void sendResult(ix::WebSocket &ws, uint64_t requestId, const json &result)
void handleHello(ix::WebSocket &ws, uint64_t requestId, const json &params)
::esi::cosim::FaultStash faultStash
void handleControlFrame(ix::WebSocket &ws, const std::string &text)
void sendError(ix::WebSocket &ws, uint64_t requestId, const std::string &code, const std::string &message)
Impl(Context &ctxt, int port)
void onOpen(ix::WebSocket &ws)
void stop(uint32_t timeoutMS)
std::condition_variable manifestReadyCV
std::map< std::string, std::unique_ptr< RpcServerWritePort > > writePorts
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
std::unordered_map< uint64_t, ChannelInfo > channelById
std::map< std::string, std::unique_ptr< RpcServerReadPort > > readPorts
void handleSubscribe(ix::WebSocket &ws, uint64_t requestId, const json &params)
utils::ReadyIdSet< uint64_t > readyIds
void onClientMessage(std::shared_ptr< ix::ConnectionState > state, ix::WebSocket &ws, const ix::WebSocketMessagePtr &msg)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
void handleBinaryFrame(const std::string &data)
std::unique_ptr< ClientSession > session
RpcServer(Context &ctxt)
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
Register a read or write port which communicates over RPC.
void run(int port=-1)
void stop(uint32_t timeoutMS=0)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Set the manifest and version.
std::unique_ptr< Impl > impl
Definition RpcServer.h:65
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
Multi-producer / single-consumer dirty-set of channel ids, with CV-style blocking drain semantics.
Definition Utils.h:177
void requestShutdown()
Signal a clean shutdown: wakes every current and future waitDrain caller, which will then observe fal...
Definition Utils.h:210
bool isShutdown() const
True once requestShutdown() has been called.
Definition Utils.h:219
bool waitDrain(std::unordered_set< ID > &out, std::optional< std::chrono::milliseconds > backoff={})
Block until either requestShutdown() is called or the set is non-empty, then atomically swap the curr...
Definition Utils.h:196
void markReady(ID id)
Add id to the dirty set and wake the consumer (if any).
Definition Utils.h:181
Thread safe queue.
Definition Utils.h:48
void push(E... t)
Push onto the queue.
Definition Utils.h:86
std::string buildDataFrame(uint64_t channelId, const uint8_t *bytes, size_t size)
Pack a cosim binary data frame: [u64 LE channel_id][payload].
Definition RpcWire.h:29
bool parseDataFrame(const std::string &data, uint64_t &channelId, const uint8_t *&payload, size_t &payloadSize)
Parse a cosim binary data frame.
Definition RpcWire.h:48
Definition esi.py:1
std::unordered_set< uint64_t > subscribed