CIRCT 23.0.0git
Loading...
Searching...
No Matches
Accelerator.cpp
Go to the documentation of this file.
1//===- Accelerator.cpp - ESI accelerator system API -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DO NOT EDIT!
10// This file is distributed as part of an ESI package. The source for this file
11// should always be modified within CIRCT (lib/dialect/ESI/runtime/cpp/).
12//
13//===----------------------------------------------------------------------===//
14
15#include "esi/Accelerator.h"
16
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <vector>
26
27#ifdef __linux__
28#include <dlfcn.h>
29#include <linux/limits.h>
30#include <unistd.h>
31#elif _WIN32
32#include <windows.h>
33#endif
34
35using namespace esi;
36using namespace esi::services;
37
38namespace esi {
40 : ctxt(ctxt), serviceThread(nullptr) {}
42
44 if (!serviceThread)
45 serviceThread = std::make_unique<AcceleratorServiceThread>();
46 return serviceThread.get();
47}
48void AcceleratorConnection::createEngine(const std::string &engineTypeName,
49 AppIDPath idPath,
50 const ServiceImplDetails &details,
51 const HWClientDetails &clients) {
52 std::unique_ptr<Engine> engine = ::esi::registry::createEngine(
53 *this, engineTypeName, idPath, details, clients);
54 registerEngine(idPath, std::move(engine), clients);
55}
56
58 std::unique_ptr<Engine> engine,
59 const HWClientDetails &clients) {
60 assert(engine);
61 auto [engineIter, _] = ownedEngines.emplace(idPath, std::move(engine));
62
63 // Engine is now owned by the accelerator connection, so the std::unique_ptr
64 // is no longer valid. Resolve a new one from the map iter.
65 Engine *enginePtr = engineIter->second.get();
66 // Compute our parents idPath path.
67 AppIDPath prefix = std::move(idPath);
68 if (prefix.size() > 0)
69 prefix.pop_back();
70
71 for (const auto &client : clients) {
72 AppIDPath fullClientPath = prefix + client.relPath;
73 for (const auto &channel : client.channelAssignments)
74 clientEngines[fullClientPath].setEngine(channel.first, enginePtr);
75 }
76}
77
79 AppIDPath id,
80 std::string implName,
81 ServiceImplDetails details,
82 HWClientDetails clients) {
83 std::unique_ptr<Service> &cacheEntry =
84 serviceCache[make_tuple(std::string(svcType.name()), id)];
85 if (cacheEntry == nullptr) {
86 Service *svc = createService(svcType, id, implName, details, clients);
87 if (!svc)
88 svc = ServiceRegistry::createService(this, svcType, id, implName, details,
89 clients);
90 if (!svc)
91 return nullptr;
92 cacheEntry = std::unique_ptr<Service>(svc);
93 }
94 return cacheEntry.get();
95}
96
98AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
100 throw std::runtime_error(
101 "AcceleratorConnection already owns an accelerator");
102 ownedAccelerator = std::move(acc);
103 return ownedAccelerator.get();
104}
105
106/// Get the path to the currently running executable.
107static std::filesystem::path getExePath() {
108#ifdef __linux__
109 char result[PATH_MAX];
110 ssize_t count = readlink("/proc/self/exe", result, PATH_MAX);
111 if (count == -1)
112 throw std::runtime_error("Could not get executable path");
113 return std::filesystem::path(std::string(result, count));
114#elif _WIN32
115 char buffer[MAX_PATH];
116 DWORD length = GetModuleFileNameA(NULL, buffer, MAX_PATH);
117 if (length == 0)
118 throw std::runtime_error("Could not get executable path");
119 return std::filesystem::path(std::string(buffer, length));
120#else
121#eror "Unsupported platform"
122#endif
123}
124
/// Get the path to the module (shared library, or the executable when this
/// code is linked statically) that contains this function.
///
/// Returns an empty path if the containing module cannot be identified.
/// \throws std::runtime_error on Windows if the module's file name cannot be
/// queried.
static std::filesystem::path getLibPath() {
#ifdef __linux__
  Dl_info dlInfo;
  // dladdr() returns 0 on failure and then leaves dlInfo unfilled; guard
  // against a null file name defensively as well rather than reading garbage.
  if (dladdr(reinterpret_cast<void *>(&getLibPath), &dlInfo) == 0 ||
      dlInfo.dli_fname == nullptr)
    return std::filesystem::path();
  return std::filesystem::path(std::string(dlInfo.dli_fname));
#elif _WIN32
  HMODULE hModule = NULL;
  if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                              GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                          reinterpret_cast<LPCSTR>(&getLibPath), &hModule))
    return std::filesystem::path();

  // GetModuleFileNameA returns the buffer size when the path was truncated,
  // so grow the buffer and retry until the whole name fits.
  std::vector<char> buffer(MAX_PATH);
  while (true) {
    DWORD length = GetModuleFileNameA(hModule, buffer.data(),
                                      static_cast<DWORD>(buffer.size()));
    if (length == 0)
      throw std::runtime_error("Could not get library path");
    if (length < buffer.size())
      return std::filesystem::path(std::string(buffer.data(), length));
    buffer.resize(buffer.size() * 2);
  }
#else
#error "Unsupported platform"
#endif
}
150
151/// Get the list of directories to search for backend plugins.
152static std::vector<std::filesystem::path> getESIBackendDirectories() {
153 std::vector<std::filesystem::path> directories;
154
155 // First, check current directory.
156 directories.push_back(std::filesystem::current_path());
157
158 // Next, parse the ESI_BACKENDS environment variable and add those.
159 const char *esiBackends = std::getenv("ESI_BACKENDS");
160 if (esiBackends) {
161 // Use platform-specific path separator
162#ifdef _WIN32
163 const char separator = ';';
164#else
165 const char separator = ':';
166#endif
167
168 std::string pathsStr(esiBackends);
169 std::stringstream ss(pathsStr);
170 std::string path;
171
172 while (std::getline(ss, path, separator))
173 if (!path.empty())
174 directories.emplace_back(path);
175 }
176
177 // Next, try the directory of the executable.
178 directories.push_back(getExePath().parent_path());
179 // Finally, try the directory of the library.
180 directories.push_back(getLibPath().parent_path());
181
182 return directories;
183}
184
185/// Load a backend plugin dynamically. Plugins are expected to be named
186/// lib<BackendName>Backend.so (<BackendName>Backend.dll on Windows) and be
187/// located in one of 1) CWD, 2) directories in the ESI_BACKENDS environment
188/// variable, 3) the application's directory, or 4) this library's directory.
189static void loadBackend(Context &ctxt, std::string backend) {
190 Logger &logger = ctxt.getLogger();
191 backend[0] = toupper(backend[0]);
192
193 // Get the file name we are looking for.
194#ifdef __linux__
195 std::string backendFileName = "lib" + backend + "Backend.so";
196#elif _WIN32
197 // In MSVC debug builds, load the debug variant of the plugin DLL (e.g.
198 // CosimBackend_d.dll) to ensure compatibility.
199#if defined(_MSC_VER) && defined(_DEBUG)
200 std::string backendFileName = backend + "Backend_d.dll";
201#else
202 std::string backendFileName = backend + "Backend.dll";
203#endif
204#else
205#error "Unsupported platform"
206#endif
207
208 // First, try the current directory.
209 std::filesystem::path backendPath;
210 // Next, try directories specified in ESI_BACKENDS environment variable.
211 std::vector<std::filesystem::path> esiBackendDirs =
213 bool found = false;
214 for (const auto &dir : esiBackendDirs) {
215 backendPath = dir / backendFileName;
216 logger.debug("CONNECT",
217 "trying to find backend plugin: " + backendPath.string());
218 if (std::filesystem::exists(backendPath)) {
219 found = true;
220 break;
221 }
222 }
223
224 // If the path was found, convert it to a string.
225 if (found) {
226 backendPath = std::filesystem::absolute(backendPath);
227 logger.debug("CONNECT", "found backend plugin: " + backendPath.string());
228 } else {
229 // If all else fails, just try the name.
230 backendPath = backendFileName;
231 logger.debug("CONNECT",
232 "trying to find backend plugin: " + backendFileName);
233 }
234
235 // Attempt to load it.
236#ifdef __linux__
237 void *handle = dlopen(backendPath.string().c_str(), RTLD_NOW | RTLD_GLOBAL);
238 if (!handle) {
239 std::string error(dlerror());
240 logger.error("CONNECT",
241 "while attempting to load backend plugin: " + error);
242 throw std::runtime_error("While attempting to load backend plugin: " +
243 error);
244 }
245#elif _WIN32
246 // Set the DLL directory to the same directory as the backend DLL in case it
247 // has transitive dependencies.
248 if (found) {
249 std::filesystem::path backendPathParent = backendPath.parent_path();
250 // If backendPath has no parent directory (e.g., it's a relative path or
251 // a filename without a directory), fallback to the current working
252 // directory. This ensures a valid directory is used for setting the DLL
253 // search path.
254 if (backendPathParent.empty())
255 backendPathParent = std::filesystem::current_path();
256 logger.debug("CONNECT", "setting DLL search directory to: " +
257 backendPathParent.string());
258 if (SetDllDirectoryA(backendPathParent.string().c_str()) == 0)
259 throw std::runtime_error("While setting DLL directory: " +
260 std::to_string(GetLastError()));
261 }
262
263 // Load the backend plugin.
264 HMODULE handle = LoadLibraryA(backendPath.string().c_str());
265 if (!handle) {
266 DWORD error = GetLastError();
267 // Get the error message string
268 LPSTR messageBuffer = nullptr;
269 size_t size = FormatMessageA(
270 FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
271 FORMAT_MESSAGE_IGNORE_INSERTS,
272 nullptr, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
273 (LPSTR)&messageBuffer, 0, nullptr);
274
275 std::string errorMessage;
276 if (size > 0 && messageBuffer != nullptr) {
277 errorMessage = std::string(messageBuffer, size);
278 LocalFree(messageBuffer);
279 } else {
280 errorMessage = "Unknown error";
281 }
282
283 std::string fullError = "While attempting to load backend plugin '" +
284 backendPath.string() + "': " + errorMessage +
285 " (error code: " + std::to_string(error) + ")";
286
287 logger.error("CONNECT", fullError);
288 throw std::runtime_error(fullError);
289 }
290#else
291#eror "Unsupported platform"
292#endif
293 logger.info("CONNECT", "loaded backend plugin: " + backendPath.string());
294}
295
296namespace registry {
297namespace internal {
298
300public:
301 static std::map<std::string, BackendCreate> &get() {
302 static BackendRegistry instance;
303 return instance.backendRegistry;
304 }
305
306private:
307 std::map<std::string, BackendCreate> backendRegistry;
308};
309
310void registerBackend(const std::string &name, BackendCreate create) {
311 auto &registry = BackendRegistry::get();
312 if (registry.count(name))
313 throw std::runtime_error("Backend already exists in registry");
314 registry[name] = create;
315}
316} // namespace internal
317
318} // namespace registry
319
321 std::string connection) {
323 auto f = registry.find(backend);
324 if (f == registry.end()) {
325 // If it's not already found in the registry, try to load it dynamically.
326 loadBackend(*this, backend);
327 f = registry.find(backend);
328 if (f == registry.end()) {
329 ServiceImplDetails details;
330 details["backend"] = backend;
331 std::ostringstream loaded_backends;
332 bool first = true;
333 for (const auto &b : registry) {
334 if (!first)
335 loaded_backends << ", ";
336 loaded_backends << b.first;
337 first = false;
338 }
339 details["loaded_backends"] = loaded_backends.str();
340 getLogger().error("CONNECT", "backend '" + backend + "' not found",
341 &details);
342 throw std::runtime_error("Backend '" + backend + "' not found");
343 }
344 }
345 getLogger().info("CONNECT", "connecting to backend " + backend + " via '" +
346 connection + "'");
347 auto conn = f->second(*this, connection);
348 auto *connPtr = conn.get();
349 connections.emplace_back(std::move(conn));
350 return connPtr;
351}
352
354 Impl() {}
355 void start() { me = std::thread(&Impl::loop, this); }
356 void stop() {
357 shutdown = true;
358 me.join();
359 }
360 /// When there's data on any of the listenPorts, call the callback. This
361 /// method can be called from any thread.
362 void
363 addListener(std::initializer_list<ReadChannelPort *> listenPorts,
364 std::function<void(ReadChannelPort *, MessageData)> callback);
365
366 void addTask(std::function<void(void)> task) {
367 std::lock_guard<std::mutex> g(m);
368 taskList.push_back(task);
369 }
370
371private:
372 void loop();
373 volatile bool shutdown = false;
374 std::thread me;
375
376 // Protect the shared data structures.
377 std::mutex m;
378
379 // Map of read ports to callbacks.
380 std::map<ReadChannelPort *,
381 std::pair<std::function<void(ReadChannelPort *, MessageData)>,
382 std::future<MessageData>>>
384
385 /// Tasks which should be called on every loop iteration.
386 std::vector<std::function<void(void)>> taskList;
387};
388
389void AcceleratorServiceThread::Impl::loop() {
390 // These two variables should logically be in the loop, but this avoids
391 // reconstructing them on each iteration.
392 std::vector<std::tuple<ReadChannelPort *,
393 std::function<void(ReadChannelPort *, MessageData)>,
395 portUnlockWorkList;
396 std::vector<std::function<void(void)>> taskListCopy;
397 MessageData data;
398
399 while (!shutdown) {
400 // Ideally we'd have some wake notification here, but this suffices for
401 // now.
402 // TODO: investigate better ways to do this. For now, just play nice with
403 // the other processes but don't waste time in between polling intervals.
404 std::this_thread::yield();
405
406 // Check and gather data from all the read ports we are monitoring. Put the
407 // callbacks to be called later so we can release the lock.
408 {
409 std::lock_guard<std::mutex> g(m);
410 for (auto &[channel, cbfPair] : listeners) {
411 assert(channel && "Null channel in listener list");
412 std::future<MessageData> &f = cbfPair.second;
413 if (f.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
414 portUnlockWorkList.emplace_back(channel, cbfPair.first, f.get());
415 f = channel->readAsync();
416 }
417 }
418 }
419
420 // Call the callbacks outside the lock.
421 for (auto [channel, cb, data] : portUnlockWorkList)
422 cb(channel, std::move(data));
423
424 // Clear the worklist for the next iteration.
425 portUnlockWorkList.clear();
426
427 // Call any tasks that have been added. Copy it first so we can release the
428 // lock ASAP.
429 {
430 std::lock_guard<std::mutex> g(m);
431 taskListCopy = taskList;
432 }
433 for (auto &task : taskListCopy)
434 task();
435 }
436}
437
438void AcceleratorServiceThread::Impl::addListener(
439 std::initializer_list<ReadChannelPort *> listenPorts,
440 std::function<void(ReadChannelPort *, MessageData)> callback) {
441 std::lock_guard<std::mutex> g(m);
442 for (auto port : listenPorts) {
443 if (listeners.count(port))
444 throw std::runtime_error("Port already has a listener");
445 listeners[port] = std::make_pair(callback, port->readAsync());
446 }
447}
448
450 : impl(std::make_unique<Impl>()) {
451 impl->start();
452}
454
456 if (impl) {
457 impl->stop();
458 impl.reset();
459 }
460}
461
462// When there's data on any of the listenPorts, call the callback. This is
463// largely redundant now that we have callback port support, especially given
464// the polling loop. Keep the functionality for now.
466 std::initializer_list<ReadChannelPort *> listenPorts,
467 std::function<void(ReadChannelPort *, MessageData)> callback) {
468 assert(impl && "Service thread not running");
469 impl->addListener(listenPorts, callback);
470}
471
473 assert(impl && "Service thread not running");
474 impl->addTask([&module]() { module.poll(); });
475}
476
478 if (serviceThread) {
479 serviceThread->stop();
480 serviceThread.reset();
481 }
482}
483
484} // namespace esi
assert(baseType &&"element must be base type")
Abstract class representing a connection to an accelerator.
Definition Accelerator.h:89
virtual Service * createService(Service::Type service, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients)=0
Called by getServiceImpl exclusively.
ServiceClass * getService(AppIDPath id={}, std::string implName={}, ServiceImplDetails details={}, HWClientDetails clients={})
Get a typed reference to a particular service type.
std::map< AppIDPath, BundleEngineMap > clientEngines
Mapping of clients to their servicing engines.
void registerEngine(AppIDPath idPath, std::unique_ptr< Engine > engine, const HWClientDetails &clients)
If createEngine is overridden, this method should be called to register the engine and all of the cha...
std::map< ServiceCacheKey, std::unique_ptr< Service > > serviceCache
std::unique_ptr< AcceleratorServiceThread > serviceThread
std::unique_ptr< Accelerator > ownedAccelerator
Accelerator object owned by this connection.
virtual void disconnect()
Disconnect from the accelerator cleanly.
std::map< AppIDPath, std::unique_ptr< Engine > > ownedEngines
Collection of owned engines.
virtual void createEngine(const std::string &engineTypeName, AppIDPath idPath, const ServiceImplDetails &details, const HWClientDetails &clients)
Create a new engine for channel communication with the accelerator.
AcceleratorServiceThread * getServiceThread()
Return a pointer to the accelerator 'service' thread (or threads).
AcceleratorConnection(Context &ctxt)
Accelerator * takeOwnership(std::unique_ptr< Accelerator > accel)
Assume ownership of an accelerator object.
Background thread which services various requests.
std::unique_ptr< Impl > impl
void stop()
Instruct the service thread to stop running.
void addPoll(HWModule &module)
Poll this module.
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
Top level accelerator class.
Definition Accelerator.h:70
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
std::vector< std::unique_ptr< AcceleratorConnection > > connections
Definition Context.h:73
AcceleratorConnection * connect(std::string backend, std::string connection)
Connect to an accelerator backend.
Engines implement the actual channel communication between the host and the accelerator.
Definition Engines.h:42
Represents either the top level or an instance of a hardware module.
Definition Design.h:47
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:64
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:75
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:83
A logical chunk of data representing serialized data.
Definition Common.h:113
A ChannelPort which reads data from the accelerator.
Definition Ports.h:318
std::map< std::string, BackendCreate > backendRegistry
static std::map< std::string, BackendCreate > & get()
static Service * createService(AcceleratorConnection *acc, Service::Type svcType, AppIDPath id, std::string implName, ServiceImplDetails details, HWClientDetails clients)
Create a service instance from the given details.
Definition Services.cpp:417
Parent class of all APIs modeled as 'services'.
Definition Services.h:59
const std::type_info & Type
Definition Services.h:61
void registerBackend(const std::string &name, BackendCreate create)
std::function< std::unique_ptr< AcceleratorConnection >(Context &, std::string)> BackendCreate
Backends can register themselves to be connected via a connection string.
std::unique_ptr< Engine > createEngine(AcceleratorConnection &conn, const std::string &dmaEngineName, AppIDPath idPath, const ServiceImplDetails &details, const HWClientDetails &clients)
Create an engine by name.
Definition Engines.cpp:509
Definition esi.py:1
static std::filesystem::path getExePath()
Get the path to the currently running executable.
std::map< std::string, std::any > ServiceImplDetails
Definition Common.h:108
static void loadBackend(Context &ctxt, std::string backend)
Load a backend plugin dynamically.
static std::filesystem::path getLibPath()
Get the path to the currently running shared library.
static std::vector< std::filesystem::path > getESIBackendDirectories()
Get the list of directories to search for backend plugins.
std::vector< HWClientDetail > HWClientDetails
Definition Common.h:107
std::map< ReadChannelPort *, std::pair< std::function< void(ReadChannelPort *, MessageData)>, std::future< MessageData > > > listeners
void addTask(std::function< void(void)> task)
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::vector< std::function< void(void)> > taskList
Tasks which should be called on every loop iteration.