21#include <ixwebsocket/IXBase64.h>
22#include <ixwebsocket/IXNetSystem.h>
23#include <ixwebsocket/IXWebSocket.h>
24#include <ixwebsocket/IXWebSocketServer.h>
25#include <nlohmann/json.hpp>
29#include <condition_variable>
39#include <unordered_map>
40#include <unordered_set>
48#include <netinet/in.h>
49#include <sys/socket.h>
55using json = nlohmann::json;
66static void writePortFile(uint16_t port) {
67 FILE *fd = fopen(
"cosim.cfg",
"w");
70 fprintf(fd,
"port: %u\n",
static_cast<unsigned int>(port));
79static int pickEphemeralPort() {
81 SOCKET fd = socket(AF_INET, SOCK_STREAM, 0);
82 if (fd == INVALID_SOCKET)
85 int fd = socket(AF_INET, SOCK_STREAM, 0);
90 setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
91 reinterpret_cast<const char *
>(&enable),
sizeof(enable));
93 addr.sin_family = AF_INET;
94 addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
96 if (bind(fd,
reinterpret_cast<sockaddr *
>(&addr),
sizeof(addr)) < 0) {
105 socklen_t len =
sizeof(bound);
106 int rc = getsockname(fd,
reinterpret_cast<sockaddr *
>(&bound), &len);
114 return ntohs(bound.sin_port);
117class RpcServerReadPort;
118class RpcServerWritePort;
137 const std::string &type);
139 const std::string &type);
/// Shut down the server.
/// NOTE(review): the exact meaning of `timeoutMS` is not visible in this
/// excerpt (RpcServer::stop forwards it and defaults it to 0) — confirm
/// against the definition.
void stop(uint32_t timeoutMS);
/// To-server (read) ports, keyed by channel name. This map owns the ports;
/// the matching ChannelInfo entry holds a non-owning raw pointer to each one.
std::map<std::string, std::unique_ptr<RpcServerReadPort>>
readPorts;
/// To-client (write) ports. This map owns them; ChannelInfo keeps a
/// non-owning raw pointer. Presumably keyed by channel name like readPorts
/// — verify against registerWritePort.
std::map<std::string, std::unique_ptr<RpcServerWritePort>>
writePorts;
/// Channel name -> channel id.
/// NOTE(review): where this map is populated is not visible here.
std::unordered_map<std::string, uint64_t>
idByName;
/// The WebSocket server. The constructor creates it bound to 127.0.0.1.
std::unique_ptr<ix::WebSocketServer>
server;
205 ix::WebSocket &ws,
const ix::WebSocketMessagePtr &msg);
210 void handleHello(ix::WebSocket &ws, uint64_t requestId,
const json ¶ms);
/// Send a successful response for `requestId` on `ws`.
/// Builds a JSON object {"type": "response", "request_id": requestId,
/// "result": result} and sends its serialized form as a UTF-8 text frame.
void sendResult(ix::WebSocket &ws, uint64_t requestId,
                const json &result);
/// Send an error response for `requestId` on `ws`.
/// Sends {"type": "response", "request_id": requestId,
///        "error": {"code": code, "message": message}}
/// serialized as a UTF-8 text frame. Callers use codes such as
/// "protocol_error", "unknown_channel", "wrong_direction",
/// "not_subscribed" and "internal".
void sendError(ix::WebSocket &ws, uint64_t requestId,
               const std::string &code,
               const std::string &message);
255 void connect(
const ConnectOptions &options = {})
override {
256 ConnectOptions forced = options;
257 forced.bufferSize = 0;
265 std::unique_ptr<SegmentedMessageData> msg =
266 std::make_unique<MessageData>(data);
277 RpcServerWritePort(
Type *type,
Impl &impl, uint64_t channelId)
282 uint64_t getChannelId()
const {
return channelId; }
304 if (!ix::initNetSystem())
305 throw std::runtime_error(
306 "RpcServer: ix::initNetSystem() failed (WSAStartup)");
310 int requestedPort = port;
311 if (requestedPort <= 0) {
312 requestedPort = pickEphemeralPort();
313 if (requestedPort <= 0)
314 throw std::runtime_error(
315 "RpcServer: failed to obtain an ephemeral TCP port");
318 const std::string host =
"127.0.0.1";
319 server = std::make_unique<ix::WebSocketServer>(requestedPort, host);
320 server->disablePerMessageDeflate();
322 server->setOnClientMessageCallback(
323 [
this](std::shared_ptr<ix::ConnectionState> state, ix::WebSocket &ws,
324 const ix::WebSocketMessagePtr &msg) {
328 auto res =
server->listen();
330 throw std::runtime_error(
"RpcServer: listen failed: " + res.second);
334 writePortFile(
static_cast<uint16_t
>(
boundPort));
353 }
catch (
const std::exception &e) {
356 std::string(
"Suppressed exception during ~RpcServer::Impl: ") +
360 "cosim",
"Suppressed non-std::exception during ~RpcServer::Impl");
366 const std::vector<uint8_t> &compressedManifest) {
378 const std::string &type) {
382 auto port = std::make_unique<RpcServerReadPort>(
new Type(type));
383 RpcServerReadPort *raw = port.get();
384 readPorts.emplace(name, std::move(port));
386 ChannelInfo info{id, name, type, ChannelDirection::ToServer, raw,
nullptr};
393 const std::string &type) {
397 auto port = std::make_unique<RpcServerWritePort>(
new Type(type), *
this,
id);
398 RpcServerWritePort *raw = port.get();
401 ChannelInfo info{id, name, type, ChannelDirection::ToClient,
nullptr, raw};
447 const ix::WebSocketMessagePtr &msg) {
449 case ix::WebSocketMessageType::Open: {
457 "cosim",
"Rejecting additional client; one already connected");
459 err[
"type"] =
"error";
460 err[
"error"] = {{
"code",
"server_busy"},
461 {
"message",
"cosim server allows only one client at a "
462 "time; another client is already connected"}};
463 ws.sendUtf8Text(err.dump());
465 ws.close(1013,
"cosim server busy: another client is already connected");
468 session = std::make_unique<ClientSession>();
473 case ix::WebSocketMessageType::Close:
474 case ix::WebSocketMessageType::Error: {
476 "cosim", std::format(
"Client disconnected: {}",
477 msg->type == ix::WebSocketMessageType::Error
478 ? msg->errorInfo.reason
479 : msg->closeInfo.reason));
483 case ix::WebSocketMessageType::Message:
508 const uint8_t *payloadBytes;
510 if (!
parseDataFrame(data, channelId, payloadBytes, payloadSize)) {
512 "Received binary frame shorter than 8-byte header");
516 RpcServerReadPort *port =
nullptr;
521 it->second.direction != ChannelDirection::ToServer)
524 port = it->second.readPort;
528 "cosim", std::format(
"Binary frame for unknown to-server channel id {}",
538 if (!port->deliver(payload))
541 std::format(
"Dropped {} bytes for channel id {}: port not connected",
542 payload.
getSize(), channelId));
548 req = json::parse(text);
549 }
catch (
const std::exception &e) {
551 "cosim", std::format(
"Failed to parse control frame: {}", e.what()));
553 std::string(
"Failed to parse JSON: ") + e.what());
558 auto typeIt = req.find(
"type");
559 if (typeIt == req.end() || !typeIt->is_string() ||
560 typeIt->get<std::string>() !=
"request") {
562 "Control frame missing \"type\":\"request\"");
565 uint64_t requestId = 0;
566 if (
auto idIt = req.find(
"request_id"); idIt != req.end()) {
568 if (!idIt->is_number_unsigned()) {
570 "\"request_id\" must be an unsigned integer");
574 requestId = idIt->get<uint64_t>();
575 }
catch (
const std::exception &e) {
577 std::string(
"Invalid \"request_id\": ") + e.what());
581 auto methodIt = req.find(
"method");
582 if (methodIt == req.end() || !methodIt->is_string()) {
583 sendError(ws, requestId,
"protocol_error",
"Missing \"method\"");
586 std::string method = methodIt->get<std::string>();
587 json params = req.value(
"params", json::object());
589 if (method ==
"hello")
591 else if (method ==
"subscribe")
593 else if (method ==
"unsubscribe")
596 sendError(ws, requestId,
"protocol_error",
"Unknown method: " + method);
613 result[
"protocol_version"] = 3;
617 result[
"compressed_manifest_b64"] = macaron::Base64::Encode(
622 json channelsJson = json::array();
632 c[
"channel_id"] = info.id;
633 c[
"name"] = info.name;
634 c[
"type"] = info.typeId;
635 c[
"direction"] = info.direction == ChannelDirection::ToServer
638 channelsJson.push_back(std::move(c));
641 result[
"channels"] = std::move(channelsJson);
653 const json ¶ms) {
654 auto chIdIt = params.find(
"channel_id");
655 if (chIdIt == params.end() || !chIdIt->is_number_unsigned()) {
656 sendError(ws, requestId,
"protocol_error",
657 "subscribe requires unsigned \"channel_id\"");
660 uint64_t channelId = chIdIt->get<uint64_t>();
666 sendError(ws, requestId,
"unknown_channel",
667 std::format(
"No channel with id {}", channelId));
670 if (it->second.direction != ChannelDirection::ToClient) {
672 ws, requestId,
"wrong_direction",
673 std::format(
"Channel id {} is not a to-client channel", channelId));
679 sendError(ws, requestId,
"internal",
"No active session");
682 std::lock_guard<std::mutex> subLock(
session->subscribedMutex);
683 session->subscribed.insert(channelId);
703 const json ¶ms) {
704 auto chIdIt = params.find(
"channel_id");
705 if (chIdIt == params.end() || !chIdIt->is_number_unsigned()) {
706 sendError(ws, requestId,
"protocol_error",
707 "unsubscribe requires unsigned \"channel_id\"");
710 uint64_t channelId = chIdIt->get<uint64_t>();
714 sendError(ws, requestId,
"internal",
"No active session");
717 std::lock_guard<std::mutex> subLock(
session->subscribedMutex);
718 auto removed =
session->subscribed.erase(channelId);
720 sendError(ws, requestId,
"not_subscribed",
721 std::format(
"Channel id {} is not subscribed", channelId));
728 const json &result) {
730 resp[
"type"] =
"response";
731 resp[
"request_id"] = requestId;
732 resp[
"result"] = result;
736 ws.sendUtf8Text(resp.dump());
740 const std::string &code,
const std::string &message) {
742 resp[
"type"] =
"response";
743 resp[
"request_id"] = requestId;
744 resp[
"error"] = {{
"code", code}, {
"message", message}};
745 ws.sendUtf8Text(resp.dump());
754 std::unordered_set<uint64_t> ids;
757 for (uint64_t
id : ids) {
758 RpcServerWritePort *writePort =
nullptr;
763 it->second.direction != ChannelDirection::ToClient)
765 writePort = it->second.writePort;
785 std::lock_guard<std::mutex> subLock(
session->subscribedMutex);
786 if (!
session->subscribed.count(
id))
788 std::optional<MessageData> msg = writePort->writeQueue.pop();
791 std::string frame =
buildDataFrame(
id, msg->getBytes(), msg->getSize());
795 session->ws->sendBinary(frame);
809 const std::vector<uint8_t> &compressedManifest) {
811 throw std::runtime_error(
"Server not running");
812 impl->setManifest(esiVersion, compressedManifest);
816 const std::string &type) {
818 throw std::runtime_error(
"Server not running");
819 return impl->registerReadPort(name, type);
823 const std::string &type) {
825 throw std::runtime_error(
"Server not running");
826 return impl->registerWritePort(name, type);
831 throw std::runtime_error(
"Server already running");
832 impl = std::make_unique<Impl>(
ctxt, port);
837 throw std::runtime_error(
"Server not running");
838 impl->stop(timeoutMS);
843 throw std::runtime_error(
"Server not running");
844 return impl->getPort();
AcceleratorConnections, Accelerators, and Manifests must all share a context.
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
virtual void warning(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a warning.
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
A concrete flat message backed by a single vector of bytes.
size_t getSize() const
Get the size of the data in bytes.
A ChannelPort which reads data from the accelerator.
virtual void connect(ReadCallback callback, const ConnectOptions &options={})
bool invokeCallback(std::unique_ptr< SegmentedMessageData > &msg)
Invoke the currently registered callback.
ReadChannelPort(const Type *type)
Root class of the ESI type system.
A ChannelPort which sends data to the accelerator.
virtual bool tryWriteImpl(const MessageData &data)=0
Implementation for tryWrite(). Subclasses must implement this.
virtual void writeImpl(const MessageData &)=0
Implementation for write(). Subclasses must implement this.
ChannelDirection
Channel direction as reported by the server.
Cross-thread error channel for the IXWebSocket network thread.
void record(std::exception_ptr ep) noexcept
Stash the first exception we see (subsequent ones are dropped – the first is the most informative).
void check()
Consume and rethrow the stored fault on the caller's thread, if any.
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
std::unique_ptr< ix::WebSocketServer > server
std::vector< uint8_t > compressedManifest
void handleUnsubscribe(ix::WebSocket &ws, uint64_t requestId, const json &params)
std::unordered_map< std::string, uint64_t > idByName
void sendResult(ix::WebSocket &ws, uint64_t requestId, const json &result)
void handleHello(ix::WebSocket &ws, uint64_t requestId, const json &params)
::esi::cosim::FaultStash faultStash
void handleControlFrame(ix::WebSocket &ws, const std::string &text)
void sendError(ix::WebSocket &ws, uint64_t requestId, const std::string &code, const std::string &message)
Impl(Context &ctxt, int port)
void onOpen(ix::WebSocket &ws)
void stop(uint32_t timeoutMS)
std::condition_variable manifestReadyCV
std::map< std::string, std::unique_ptr< RpcServerWritePort > > writePorts
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
std::unordered_map< uint64_t, ChannelInfo > channelById
std::map< std::string, std::unique_ptr< RpcServerReadPort > > readPorts
std::thread transportThread
void handleSubscribe(ix::WebSocket &ws, uint64_t requestId, const json &params)
utils::ReadyIdSet< uint64_t > readyIds
void onClientMessage(std::shared_ptr< ix::ConnectionState > state, ix::WebSocket &ws, const ix::WebSocketMessagePtr &msg)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
void handleBinaryFrame(const std::string &data)
std::unique_ptr< ClientSession > session
ReadChannelPort & registerReadPort(const std::string &name, const std::string &type)
Register a read or write port which communicates over RPC.
void stop(uint32_t timeoutMS=0)
void setManifest(int esiVersion, const std::vector< uint8_t > &compressedManifest)
Set the manifest and version.
std::unique_ptr< Impl > impl
WriteChannelPort & registerWritePort(const std::string &name, const std::string &type)
Multi-producer / single-consumer dirty-set of channel ids, with CV-style blocking drain semantics.
void requestShutdown()
Signal a clean shutdown: wakes every current and future waitDrain caller, which will then observe fal...
bool isShutdown() const
True once requestShutdown() has been called.
bool waitDrain(std::unordered_set< ID > &out, std::optional< std::chrono::milliseconds > backoff={})
Block until either requestShutdown() is called or the set is non-empty, then atomically swap the curr...
void markReady(ID id)
Add id to the dirty set and wake the consumer (if any).
void push(E... t)
Push onto the queue.
std::string buildDataFrame(uint64_t channelId, const uint8_t *bytes, size_t size)
Pack a cosim binary data frame: [u64 LE channel_id][payload].
bool parseDataFrame(const std::string &data, uint64_t &channelId, const uint8_t *&payload, size_t &payloadSize)
Parse a cosim binary data frame.
RpcServerReadPort * readPort
ChannelDirection direction
RpcServerWritePort * writePort
std::mutex subscribedMutex
std::unordered_set< uint64_t > subscribed