CIRCT 22.0.0git
Loading...
Searching...
No Matches
Accelerator.cpp
Go to the documentation of this file.
1//===- Accelerator.cpp - ESI accelerator system API -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DO NOT EDIT!
10// This file is distributed as part of an ESI package. The source for this file
11// should always be modified within CIRCT (lib/dialect/ESI/runtime/cpp/).
12//
13//===----------------------------------------------------------------------===//
14
15#include "esi/Accelerator.h"
16
17#include <cassert>
18#include <cstdlib>
19#include <filesystem>
20#include <map>
21#include <sstream>
22#include <stdexcept>
23#include <vector>
24
25#include <iostream>
26
27#ifdef __linux__
28#include <dlfcn.h>
29#include <linux/limits.h>
30#include <unistd.h>
31#elif _WIN32
32#include <windows.h>
33#endif
34
35using namespace esi;
36using namespace esi::services;
37
38namespace esi {
40 : ctxt(ctxt), serviceThread(nullptr) {}
42
44 if (!serviceThread)
45 serviceThread = std::make_unique<AcceleratorServiceThread>();
46 return serviceThread.get();
47}
48void AcceleratorConnection::createEngine(const std::string &engineTypeName,
49 AppIDPath idPath,
50 const ServiceImplDetails &details,
51 const HWClientDetails &clients) {
52 std::unique_ptr<Engine> engine = ::esi::registry::createEngine(
53 *this, engineTypeName, idPath, details, clients);
54 registerEngine(idPath, std::move(engine), clients);
55}
56
58 std::unique_ptr<Engine> engine,
59 const HWClientDetails &clients) {
60 assert(engine);
61 auto [engineIter, _] = ownedEngines.emplace(idPath, std::move(engine));
62
63 // Engine is now owned by the accelerator connection, so the std::unique_ptr
64 // is no longer valid. Resolve a new one from the map iter.
65 Engine *enginePtr = engineIter->second.get();
66 // Compute our parents idPath path.
67 AppIDPath prefix = std::move(idPath);
68 if (prefix.size() > 0)
69 prefix.pop_back();
70
71 for (const auto &client : clients) {
72 AppIDPath fullClientPath = prefix + client.relPath;
73 for (const auto &channel : client.channelAssignments)
74 clientEngines[fullClientPath].setEngine(channel.first, enginePtr);
75 }
76}
77
79 AppIDPath id,
80 std::string implName,
81 ServiceImplDetails details,
82 HWClientDetails clients) {
83 std::unique_ptr<Service> &cacheEntry =
84 serviceCache[make_tuple(std::string(svcType.name()), id)];
85 if (cacheEntry == nullptr) {
86 Service *svc = createService(svcType, id, implName, details, clients);
87 if (!svc)
88 svc = ServiceRegistry::createService(this, svcType, id, implName, details,
89 clients);
90 if (!svc)
91 return nullptr;
92 cacheEntry = std::unique_ptr<Service>(svc);
93 }
94 return cacheEntry.get();
95}
96
98AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
100 throw std::runtime_error(
101 "AcceleratorConnection already owns an accelerator");
102 ownedAccelerator = std::move(acc);
103 return ownedAccelerator.get();
104}
105
/// Get the path to the currently running executable.
///
/// On Linux the `/proc/self/exe` symlink is resolved with readlink(); on
/// Windows the module file name of the current process is queried. Any other
/// platform is rejected at compile time.
///
/// Throws std::runtime_error if the underlying platform call reports failure.
static std::filesystem::path getExePath() {
#ifdef __linux__
  char result[PATH_MAX];
  // readlink() does not NUL-terminate the buffer, so the returned byte count
  // is what delimits the string below.
  // NOTE(review): a link target longer than the buffer would be truncated
  // without an error being reported -- confirm whether that needs handling.
  const ssize_t count = readlink("/proc/self/exe", result, sizeof(result));
  if (count < 0)
    throw std::runtime_error("Could not get executable path");
  return std::filesystem::path(
      std::string(result, static_cast<std::size_t>(count)));
#elif defined(_WIN32)
  char buffer[MAX_PATH];
  // A null module handle asks for the path of the current process' executable.
  const DWORD length = GetModuleFileNameA(nullptr, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get executable path");
  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
124
/// Get the path to the currently running shared library.
///
/// The object file containing this very function is looked up: via dladdr()
/// on Linux, and via GetModuleHandleExA() + GetModuleFileNameA() on Windows.
///
/// Returns an empty path when the containing module cannot be identified.
/// Throws std::runtime_error on Windows if the module's file name cannot be
/// retrieved.
static std::filesystem::path getLibPath() {
#ifdef __linux__
  Dl_info dlInfo;
  // dladdr() returns 0 when the address cannot be matched to a loaded object;
  // in that case dlInfo must not be read. Mirror the Windows lookup-failure
  // branch and hand back an empty path.
  if (dladdr(reinterpret_cast<void *>(&getLibPath), &dlInfo) == 0 ||
      dlInfo.dli_fname == nullptr)
    return std::filesystem::path();
  return std::filesystem::path(std::string(dlInfo.dli_fname));
#elif defined(_WIN32)
  HMODULE hModule = nullptr;
  // FROM_ADDRESS: interpret the "name" argument as an address inside the
  // module. UNCHANGED_REFCOUNT: do not take a reference we would have to free.
  if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                              GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                          reinterpret_cast<LPCSTR>(&getLibPath), &hModule)) {
    // The containing module could not be identified; report "no path".
    return std::filesystem::path();
  }

  char buffer[MAX_PATH];
  const DWORD length = GetModuleFileNameA(hModule, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get library path");

  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
150
151/// Get the list of directories to search for backend plugins.
152static std::vector<std::filesystem::path> getESIBackendDirectories() {
153 std::vector<std::filesystem::path> directories;
154
155 // First, check current directory.
156 directories.push_back(std::filesystem::current_path());
157
158 // Next, parse the ESI_BACKENDS environment variable and add those.
159 const char *esiBackends = std::getenv("ESI_BACKENDS");
160 if (esiBackends) {
161 // Use platform-specific path separator
162#ifdef _WIN32
163 const char separator = ';';
164#else
165 const char separator = ':';
166#endif
167
168 std::string pathsStr(esiBackends);
169 std::stringstream ss(pathsStr);
170 std::string path;
171
172 while (std::getline(ss, path, separator))
173 if (!path.empty())
174 directories.emplace_back(path);
175 }
176
177 // Next, try the directory of the executable.
178 directories.push_back(getExePath().parent_path());
179 // Finally, try the directory of the library.
180 directories.push_back(getLibPath().parent_path());
181
182 return directories;
183}
184
185/// Load a backend plugin dynamically. Plugins are expected to be named
186/// lib<BackendName>Backend.so and located in one of 1) CWD, 2) directories
187/// specified in ESI_BACKENDS environment variable, 3) in the same directory as
188/// the application, or 4) in the same directory as this library.
189static void loadBackend(Context &ctxt, std::string backend) {
190 Logger &logger = ctxt.getLogger();
191 backend[0] = toupper(backend[0]);
192
193 // Get the file name we are looking for.
194#ifdef __linux__
195 std::string backendFileName = "lib" + backend + "Backend.so";
196#elif _WIN32
197 std::string backendFileName = backend + "Backend.dll";
198#else
199#eror "Unsupported platform"
200#endif
201
202 // First, try the current directory.
203 std::filesystem::path backendPath;
204 // Next, try directories specified in ESI_BACKENDS environment variable.
205 std::vector<std::filesystem::path> esiBackendDirs =
207 bool found = false;
208 for (const auto &dir : esiBackendDirs) {
209 backendPath = dir / backendFileName;
210 logger.debug("CONNECT",
211 "trying to find backend plugin: " + backendPath.string());
212 if (std::filesystem::exists(backendPath)) {
213 found = true;
214 break;
215 }
216 }
217
218 // If the path was found, convert it to a string.
219 if (found) {
220 backendPath = std::filesystem::absolute(backendPath);
221 logger.debug("CONNECT", "found backend plugin: " + backendPath.string());
222 } else {
223 // If all else fails, just try the name.
224 backendPath = backendFileName;
225 logger.debug("CONNECT",
226 "trying to find backend plugin: " + backendFileName);
227 }
228
229 // Attempt to load it.
230#ifdef __linux__
231 void *handle = dlopen(backendPath.string().c_str(), RTLD_NOW | RTLD_GLOBAL);
232 if (!handle) {
233 std::string error(dlerror());
234 logger.error("CONNECT",
235 "while attempting to load backend plugin: " + error);
236 throw std::runtime_error("While attempting to load backend plugin: " +
237 error);
238 }
239#elif _WIN32
240 // Set the DLL directory to the same directory as the backend DLL in case it
241 // has transitive dependencies.
242 if (found) {
243 std::filesystem::path backendPathParent = backendPath.parent_path();
244 // If backendPath has no parent directory (e.g., it's a relative path or
245 // a filename without a directory), fallback to the current working
246 // directory. This ensures a valid directory is used for setting the DLL
247 // search path.
248 if (backendPathParent.empty())
249 backendPathParent = std::filesystem::current_path();
250 logger.debug("CONNECT", "setting DLL search directory to: " +
251 backendPathParent.string());
252 if (SetDllDirectoryA(backendPathParent.string().c_str()) == 0)
253 throw std::runtime_error("While setting DLL directory: " +
254 std::to_string(GetLastError()));
255 }
256
257 // Load the backend plugin.
258 HMODULE handle = LoadLibraryA(backendPath.string().c_str());
259 if (!handle) {
260 DWORD error = GetLastError();
261 // Get the error message string
262 LPSTR messageBuffer = nullptr;
263 size_t size = FormatMessageA(
264 FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
265 FORMAT_MESSAGE_IGNORE_INSERTS,
266 nullptr, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
267 (LPSTR)&messageBuffer, 0, nullptr);
268
269 std::string errorMessage;
270 if (size > 0 && messageBuffer != nullptr) {
271 errorMessage = std::string(messageBuffer, size);
272 LocalFree(messageBuffer);
273 } else {
274 errorMessage = "Unknown error";
275 }
276
277 std::string fullError = "While attempting to load backend plugin '" +
278 backendPath.string() + "': " + errorMessage +
279 " (error code: " + std::to_string(error) + ")";
280
281 logger.error("CONNECT", fullError);
282 throw std::runtime_error(fullError);
283 }
284#else
285#eror "Unsupported platform"
286#endif
287 logger.info("CONNECT", "loaded backend plugin: " + backendPath.string());
288}
289
290namespace registry {
291namespace internal {
292
294public:
295 static std::map<std::string, BackendCreate> &get() {
296 static BackendRegistry instance;
297 return instance.backendRegistry;
298 }
299
300private:
301 std::map<std::string, BackendCreate> backendRegistry;
302};
303
304void registerBackend(const std::string &name, BackendCreate create) {
305 auto &registry = BackendRegistry::get();
306 if (registry.count(name))
307 throw std::runtime_error("Backend already exists in registry");
308 registry[name] = create;
309}
310} // namespace internal
311
312} // namespace registry
313
315 std::string connection) {
317 auto f = registry.find(backend);
318 if (f == registry.end()) {
319 // If it's not already found in the registry, try to load it dynamically.
320 loadBackend(*this, backend);
321 f = registry.find(backend);
322 if (f == registry.end()) {
323 ServiceImplDetails details;
324 details["backend"] = backend;
325 std::ostringstream loaded_backends;
326 bool first = true;
327 for (const auto &b : registry) {
328 if (!first)
329 loaded_backends << ", ";
330 loaded_backends << b.first;
331 first = false;
332 }
333 details["loaded_backends"] = loaded_backends.str();
334 getLogger().error("CONNECT", "backend '" + backend + "' not found",
335 &details);
336 throw std::runtime_error("Backend '" + backend + "' not found");
337 }
338 }
339 getLogger().info("CONNECT", "connecting to backend " + backend + " via '" +
340 connection + "'");
341 auto conn = f->second(*this, connection);
342 auto *connPtr = conn.get();
343 connections.emplace_back(std::move(conn));
344 return connPtr;
345}
346
348 Impl() {}
349 void start() { me = std::thread(&Impl::loop, this); }
350 void stop() {
351 shutdown = true;
352 me.join();
353 }
354 /// When there's data on any of the listenPorts, call the callback. This
355 /// method can be called from any thread.
356 void
357 addListener(std::initializer_list<ReadChannelPort *> listenPorts,
358 std::function<void(ReadChannelPort *, MessageData)> callback);
359
360 void addTask(std::function<void(void)> task) {
361 std::lock_guard<std::mutex> g(m);
362 taskList.push_back(task);
363 }
364
365private:
366 void loop();
367 volatile bool shutdown = false;
368 std::thread me;
369
370 // Protect the shared data structures.
371 std::mutex m;
372
373 // Map of read ports to callbacks.
374 std::map<ReadChannelPort *,
375 std::pair<std::function<void(ReadChannelPort *, MessageData)>,
376 std::future<MessageData>>>
378
379 /// Tasks which should be called on every loop iteration.
380 std::vector<std::function<void(void)>> taskList;
381};
382
383void AcceleratorServiceThread::Impl::loop() {
384 // These two variables should logically be in the loop, but this avoids
385 // reconstructing them on each iteration.
386 std::vector<std::tuple<ReadChannelPort *,
387 std::function<void(ReadChannelPort *, MessageData)>,
389 portUnlockWorkList;
390 std::vector<std::function<void(void)>> taskListCopy;
391 MessageData data;
392
393 while (!shutdown) {
394 // Ideally we'd have some wake notification here, but this suffices for
395 // now.
396 // TODO: investigate better ways to do this. For now, just play nice with
397 // the other processes but don't waste time in between polling intervals.
398 std::this_thread::yield();
399
400 // Check and gather data from all the read ports we are monitoring. Put the
401 // callbacks to be called later so we can release the lock.
402 {
403 std::lock_guard<std::mutex> g(m);
404 for (auto &[channel, cbfPair] : listeners) {
405 assert(channel && "Null channel in listener list");
406 std::future<MessageData> &f = cbfPair.second;
407 if (f.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
408 portUnlockWorkList.emplace_back(channel, cbfPair.first, f.get());
409 f = channel->readAsync();
410 }
411 }
412 }
413
414 // Call the callbacks outside the lock.
415 for (auto [channel, cb, data] : portUnlockWorkList)
416 cb(channel, std::move(data));
417
418 // Clear the worklist for the next iteration.
419 portUnlockWorkList.clear();
420
421 // Call any tasks that have been added. Copy it first so we can release the
422 // lock ASAP.
423 {
424 std::lock_guard<std::mutex> g(m);
425 taskListCopy = taskList;
426 }
427 for (auto &task : taskListCopy)
428 task();
429 }
430}
431
432void AcceleratorServiceThread::Impl::addListener(
433 std::initializer_list<ReadChannelPort *> listenPorts,
434 std::function<void(ReadChannelPort *, MessageData)> callback) {
435 std::lock_guard<std::mutex> g(m);
436 for (auto port : listenPorts) {
437 if (listeners.count(port))
438 throw std::runtime_error("Port already has a listener");
439 listeners[port] = std::make_pair(callback, port->readAsync());
440 }
441}
442
444 : impl(std::make_unique<Impl>()) {
445 impl->start();
446}
448
450 if (impl) {
451 impl->stop();
452 impl.reset();
453 }
454}
455
456// When there's data on any of the listenPorts, call the callback. This is
457// kinda silly now that we have callback port support, especially given the
458// polling loop. Keep the functionality for now.
460 std::initializer_list<ReadChannelPort *> listenPorts,
461 std::function<void(ReadChannelPort *, MessageData)> callback) {
462 assert(impl && "Service thread not running");
463 impl->addListener(listenPorts, callback);
464}
465
467 assert(impl && "Service thread not running");
468 impl->addTask([&module]() { module.poll(); });
469}
470
472 if (serviceThread) {
473 serviceThread->stop();
474 serviceThread.reset();
475 }
476}
477
478} // namespace esi
assert(baseType &&"element must be base type")
Abstract class representing a connection to an accelerator.
Definition Accelerator.h:79
virtual Service * createService(Service::Type service, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients)=0
Called by getServiceImpl exclusively.
ServiceClass * getService(AppIDPath id={}, std::string implName={}, ServiceImplDetails details={}, HWClientDetails clients={})
Get a typed reference to a particular service type.
std::map< AppIDPath, BundleEngineMap > clientEngines
Mapping of clients to their servicing engines.
void registerEngine(AppIDPath idPath, std::unique_ptr< Engine > engine, const HWClientDetails &clients)
If createEngine is overridden, this method should be called to register the engine and all of the cha...
std::map< ServiceCacheKey, std::unique_ptr< Service > > serviceCache
std::unique_ptr< AcceleratorServiceThread > serviceThread
std::unique_ptr< Accelerator > ownedAccelerator
Accelerator object owned by this connection.
virtual void disconnect()
Disconnect from the accelerator cleanly.
std::map< AppIDPath, std::unique_ptr< Engine > > ownedEngines
Collection of owned engines.
virtual void createEngine(const std::string &engineTypeName, AppIDPath idPath, const ServiceImplDetails &details, const HWClientDetails &clients)
Create a new engine for channel communication with the accelerator.
AcceleratorServiceThread * getServiceThread()
Return a pointer to the accelerator 'service' thread (or threads).
AcceleratorConnection(Context &ctxt)
Accelerator * takeOwnership(std::unique_ptr< Accelerator > accel)
Assume ownership of an accelerator object.
Background thread which services various requests.
std::unique_ptr< Impl > impl
void stop()
Instruct the service thread to stop running.
void addPoll(HWModule &module)
Poll this module.
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
Top level accelerator class.
Definition Accelerator.h:60
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
std::vector< std::unique_ptr< AcceleratorConnection > > connections
Definition Context.h:73
AcceleratorConnection * connect(std::string backend, std::string connection)
Connect to an accelerator backend.
Engines implement the actual channel communication between the host and the accelerator.
Definition Engines.h:42
Represents either the top level or an instance of a hardware module.
Definition Design.h:47
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:64
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:75
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:83
A logical chunk of data representing serialized data.
Definition Common.h:113
A ChannelPort which reads data from the accelerator.
Definition Ports.h:124
std::map< std::string, BackendCreate > backendRegistry
static std::map< std::string, BackendCreate > & get()
static Service * createService(AcceleratorConnection *acc, Service::Type svcType, AppIDPath id, std::string implName, ServiceImplDetails details, HWClientDetails clients)
Create a service instance from the given details.
Definition Services.cpp:406
Parent class of all APIs modeled as 'services'.
Definition Services.h:57
const std::type_info & Type
Definition Services.h:59
void registerBackend(const std::string &name, BackendCreate create)
std::function< std::unique_ptr< AcceleratorConnection >(Context &, std::string)> BackendCreate
Backends can register themselves to be connected via a connection string.
std::unique_ptr< Engine > createEngine(AcceleratorConnection &conn, const std::string &dmaEngineName, AppIDPath idPath, const ServiceImplDetails &details, const HWClientDetails &clients)
Create an engine by name.
Definition Engines.cpp:507
Definition esi.py:1
static std::filesystem::path getExePath()
Get the path to the currently running executable.
std::map< std::string, std::any > ServiceImplDetails
Definition Common.h:108
static void loadBackend(Context &ctxt, std::string backend)
Load a backend plugin dynamically.
static std::filesystem::path getLibPath()
Get the path to the currently running shared library.
static std::vector< std::filesystem::path > getESIBackendDirectories()
Get the list of directories to search for backend plugins.
std::vector< HWClientDetail > HWClientDetails
Definition Common.h:107
std::map< ReadChannelPort *, std::pair< std::function< void(ReadChannelPort *, MessageData)>, std::future< MessageData > > > listeners
void addTask(std::function< void(void)> task)
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::vector< std::function< void(void)> > taskList
Tasks which should be called on every loop iteration.