CIRCT 20.0.0git
Loading...
Searching...
No Matches
Accelerator.cpp
Go to the documentation of this file.
1//===- Accelerator.cpp - ESI accelerator system API -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DO NOT EDIT!
10// This file is distributed as part of an ESI package. The source for this file
11// should always be modified within CIRCT (lib/dialect/ESI/runtime/cpp/).
12//
13//===----------------------------------------------------------------------===//
14
15#include "esi/Accelerator.h"
16
17#include <cassert>
18#include <filesystem>
19#include <map>
20#include <stdexcept>
21
22#include <iostream>
23
24#ifdef __linux__
25#include <dlfcn.h>
26#include <linux/limits.h>
27#include <unistd.h>
28#elif _WIN32
29#include <windows.h>
30#endif
31
32using namespace esi;
33using namespace esi::services;
34
35namespace esi {
37 : ctxt(ctxt), serviceThread(nullptr) {}
39
41 if (!serviceThread)
42 serviceThread = std::make_unique<AcceleratorServiceThread>();
43 return serviceThread.get();
44}
45
47 AppIDPath id,
48 std::string implName,
49 ServiceImplDetails details,
50 HWClientDetails clients) {
51 std::unique_ptr<Service> &cacheEntry = serviceCache[make_tuple(&svcType, id)];
52 if (cacheEntry == nullptr) {
53 Service *svc = createService(svcType, id, implName, details, clients);
54 if (!svc)
55 svc = ServiceRegistry::createService(this, svcType, id, implName, details,
56 clients);
57 if (!svc)
58 return nullptr;
59 cacheEntry = std::unique_ptr<Service>(svc);
60 }
61 return cacheEntry.get();
62}
63
65AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
66 Accelerator *ret = acc.get();
67 ownedAccelerators.push_back(std::move(acc));
68 return ret;
69}
70
/// Get the path to the currently running executable.
///
/// \returns the executable path as reported by the OS (/proc/self/exe on
///          Linux, GetModuleFileNameA on Windows).
/// \throws std::runtime_error if the OS query fails.
static std::filesystem::path getExePath() {
#ifdef __linux__
  char result[PATH_MAX];
  // readlink does not NUL-terminate, so build the string from the byte count.
  // NOTE(review): a count equal to PATH_MAX may indicate truncation; this is
  // not currently detected.
  ssize_t count = readlink("/proc/self/exe", result, PATH_MAX);
  if (count == -1)
    throw std::runtime_error("Could not get executable path");
  return std::filesystem::path(std::string(result, count));
#elif _WIN32
  char buffer[MAX_PATH];
  // A null module handle selects the executable of the current process.
  DWORD length = GetModuleFileNameA(nullptr, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get executable path");
  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
89
/// Get the path to the currently running shared library, i.e. the module
/// that contains this function's code.
///
/// \returns the module's path, or an empty path if the containing module
///          cannot be identified.
/// \throws std::runtime_error on Windows if the module was identified but its
///         file name cannot be retrieved.
static std::filesystem::path getLibPath() {
#ifdef __linux__
  Dl_info dl_info{};
  // dladdr returns 0 on failure, leaving dl_info unfilled, and dli_fname may
  // be null. Report "unknown" (an empty path) in both cases rather than
  // constructing a string from an invalid pointer.
  if (dladdr(reinterpret_cast<void *>(&getLibPath), &dl_info) == 0 ||
      dl_info.dli_fname == nullptr)
    return std::filesystem::path();
  return std::filesystem::path(std::string(dl_info.dli_fname));
#elif _WIN32
  HMODULE hModule = nullptr;
  // Look up the module containing getLibPath without bumping its refcount.
  if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                              GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                          reinterpret_cast<LPCSTR>(&getLibPath), &hModule)) {
    // Could not identify the containing module; report "unknown".
    return std::filesystem::path();
  }

  char buffer[MAX_PATH];
  DWORD length = GetModuleFileNameA(hModule, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get library path");

  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
115
116/// Load a backend plugin dynamically. Plugins are expected to be named
117/// lib<BackendName>Backend.so and located in one of 1) CWD, 2) in the same
118/// directory as the application, or 3) in the same directory as this library.
119static void loadBackend(Context &ctxt, std::string backend) {
120 Logger &logger = ctxt.getLogger();
121 backend[0] = toupper(backend[0]);
122
123 // Get the file name we are looking for.
124#ifdef __linux__
125 std::string backendFileName = "lib" + backend + "Backend.so";
126#elif _WIN32
127 std::string backendFileName = backend + "Backend.dll";
128#else
129#eror "Unsupported platform"
130#endif
131
132 // Look for library using the C++ std API.
133 // TODO: once the runtime has a logging framework, log the paths we are
134 // trying.
135
136 // First, try the current directory.
137 std::filesystem::path backendPath = backendFileName;
138 std::string backendPathStr;
139 logger.debug("CONNECT",
140 "trying to load backend plugin: " + backendPath.string());
141 if (!std::filesystem::exists(backendPath)) {
142 // Next, try the directory of the executable.
143 backendPath = getExePath().parent_path().append(backendFileName);
144 logger.debug("CONNECT",
145 "trying to load backend plugin: " + backendPath.string());
146 if (!std::filesystem::exists(backendPath)) {
147 // Finally, try the directory of the library.
148 backendPath = getLibPath().parent_path().append(backendFileName);
149 logger.debug("CONNECT",
150 "trying to load backend plugin: " + backendPath.string());
151 if (!std::filesystem::exists(backendPath)) {
152 // If all else fails, just try the name.
153 backendPathStr = backendFileName;
154 logger.debug("CONNECT",
155 "trying to load backend plugin: " + backendPathStr);
156 }
157 }
158 }
159 // If the path was found, convert it to a string.
160 if (backendPathStr.empty())
161 backendPathStr = backendPath.string();
162 else
163 // Otherwise, signal that the path wasn't found by clearing the path and
164 // just use the name. (This is only used on Windows to add the same
165 // directory as the backend DLL to the DLL search path.)
166 backendPath.clear();
167
168 // Attempt to load it.
169#ifdef __linux__
170 void *handle = dlopen(backendPathStr.c_str(), RTLD_NOW | RTLD_GLOBAL);
171 if (!handle) {
172 std::string error(dlerror());
173 logger.error("CONNECT",
174 "while attempting to load backend plugin: " + error);
175 throw std::runtime_error("While attempting to load backend plugin: " +
176 error);
177 }
178#elif _WIN32
179 // Set the DLL directory to the same directory as the backend DLL in case it
180 // has transitive dependencies.
181 if (backendPath != std::filesystem::path()) {
182 std::filesystem::path backendPathParent = backendPath.parent_path();
183 if (SetDllDirectoryA(backendPathParent.string().c_str()) == 0)
184 throw std::runtime_error("While setting DLL directory: " +
185 std::to_string(GetLastError()));
186 }
187
188 // Load the backend plugin.
189 HMODULE handle = LoadLibraryA(backendPathStr.c_str());
190 if (!handle) {
191 DWORD error = GetLastError();
192 if (error == ERROR_MOD_NOT_FOUND) {
193 logger.error("CONNECT", "while attempting to load backend plugin: " +
194 backendPathStr + " not found");
195 throw std::runtime_error("While attempting to load backend plugin: " +
196 backendPathStr + " not found");
197 }
198 logger.error("CONNECT", "while attempting to load backend plugin: " +
199 std::to_string(error));
200 throw std::runtime_error("While attempting to load backend plugin: " +
201 std::to_string(error));
202 }
203#else
204#eror "Unsupported platform"
205#endif
206 logger.info("CONNECT", "loaded backend plugin: " + backendPathStr);
207}
208
209namespace registry {
210namespace internal {
211
213public:
214 static std::map<std::string, BackendCreate> &get() {
215 static BackendRegistry instance;
216 return instance.backendRegistry;
217 }
218
219private:
220 std::map<std::string, BackendCreate> backendRegistry;
221};
222
223void registerBackend(const std::string &name, BackendCreate create) {
224 auto &registry = BackendRegistry::get();
225 if (registry.count(name))
226 throw std::runtime_error("Backend already exists in registry");
227 registry[name] = create;
228}
229} // namespace internal
230
231std::unique_ptr<AcceleratorConnection> connect(Context &ctxt,
232 const std::string &backend,
233 const std::string &connection) {
234 auto &registry = internal::BackendRegistry::get();
235 auto f = registry.find(backend);
236 if (f == registry.end()) {
237 // If it's not already found in the registry, try to load it dynamically.
238 loadBackend(ctxt, backend);
239 f = registry.find(backend);
240 if (f == registry.end()) {
241 ctxt.getLogger().error("CONNECT", "backend '" + backend + "' not found");
242 throw std::runtime_error("Backend '" + backend + "' not found");
243 }
244 }
245 ctxt.getLogger().info("CONNECT", "connecting to backend " + backend +
246 " via '" + connection + "'");
247 return f->second(ctxt, connection);
248}
249
250} // namespace registry
251
253 Impl() {}
254 void start() { me = std::thread(&Impl::loop, this); }
255 void stop() {
256 shutdown = true;
257 me.join();
258 }
259 /// When there's data on any of the listenPorts, call the callback. This
260 /// method can be called from any thread.
261 void
262 addListener(std::initializer_list<ReadChannelPort *> listenPorts,
263 std::function<void(ReadChannelPort *, MessageData)> callback);
264
265 void addTask(std::function<void(void)> task) {
266 std::lock_guard<std::mutex> g(m);
267 taskList.push_back(task);
268 }
269
270private:
271 void loop();
272 volatile bool shutdown = false;
273 std::thread me;
274
275 // Protect the shared data structures.
276 std::mutex m;
277
278 // Map of read ports to callbacks.
279 std::map<ReadChannelPort *,
280 std::pair<std::function<void(ReadChannelPort *, MessageData)>,
281 std::future<MessageData>>>
283
284 /// Tasks which should be called on every loop iteration.
285 std::vector<std::function<void(void)>> taskList;
286};
287
288void AcceleratorServiceThread::Impl::loop() {
289 // These two variables should logically be in the loop, but this avoids
290 // reconstructing them on each iteration.
291 std::vector<std::tuple<ReadChannelPort *,
292 std::function<void(ReadChannelPort *, MessageData)>,
294 portUnlockWorkList;
295 std::vector<std::function<void(void)>> taskListCopy;
296 MessageData data;
297
298 while (!shutdown) {
299 // Ideally we'd have some wake notification here, but this suffices for
300 // now.
301 // TODO: investigate better ways to do this.
302 std::this_thread::sleep_for(std::chrono::microseconds(100));
303
304 // Check and gather data from all the read ports we are monitoring. Put the
305 // callbacks to be called later so we can release the lock.
306 {
307 std::lock_guard<std::mutex> g(m);
308 for (auto &[channel, cbfPair] : listeners) {
309 assert(channel && "Null channel in listener list");
310 std::future<MessageData> &f = cbfPair.second;
311 if (f.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
312 portUnlockWorkList.emplace_back(channel, cbfPair.first, f.get());
313 f = channel->readAsync();
314 }
315 }
316 }
317
318 // Call the callbacks outside the lock.
319 for (auto [channel, cb, data] : portUnlockWorkList)
320 cb(channel, std::move(data));
321
322 // Clear the worklist for the next iteration.
323 portUnlockWorkList.clear();
324
325 // Call any tasks that have been added. Copy it first so we can release the
326 // lock ASAP.
327 {
328 std::lock_guard<std::mutex> g(m);
329 taskListCopy = taskList;
330 }
331 for (auto &task : taskListCopy)
332 task();
333 }
334}
335
336void AcceleratorServiceThread::Impl::addListener(
337 std::initializer_list<ReadChannelPort *> listenPorts,
338 std::function<void(ReadChannelPort *, MessageData)> callback) {
339 std::lock_guard<std::mutex> g(m);
340 for (auto port : listenPorts) {
341 if (listeners.count(port))
342 throw std::runtime_error("Port already has a listener");
343 listeners[port] = std::make_pair(callback, port->readAsync());
344 }
345}
346
347} // namespace esi
348
350 : impl(std::make_unique<Impl>()) {
351 impl->start();
352}
354
356 if (impl) {
357 impl->stop();
358 impl.reset();
359 }
360}
361
362// When there's data on any of the listenPorts, call the callback. This is
363// largely redundant now that we have callback port support, especially given
364// the polling loop. Keep the functionality for now.
366 std::initializer_list<ReadChannelPort *> listenPorts,
367 std::function<void(ReadChannelPort *, MessageData)> callback) {
368 assert(impl && "Service thread not running");
369 impl->addListener(listenPorts, callback);
370}
371
373 assert(impl && "Service thread not running");
374 impl->addTask([&module]() { module.poll(); });
375}
376
378 if (serviceThread) {
379 serviceThread->stop();
380 serviceThread.reset();
381 }
382}
assert(baseType &&"element must be base type")
virtual void disconnect()
Disconnect from the accelerator cleanly.
virtual Service * createService(Service::Type service, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients)=0
Called by getServiceImpl exclusively.
ServiceClass * getService(AppIDPath id={}, std::string implName={}, ServiceImplDetails details={}, HWClientDetails clients={})
Get a typed reference to a particular service type.
std::map< ServiceCacheKey, std::unique_ptr< Service > > serviceCache
std::unique_ptr< AcceleratorServiceThread > serviceThread
std::vector< std::unique_ptr< Accelerator > > ownedAccelerators
List of accelerator objects owned by this connection.
AcceleratorServiceThread * getServiceThread()
Return a pointer to the accelerator 'service' thread (or threads).
AcceleratorConnection(Context &ctxt)
Accelerator * takeOwnership(std::unique_ptr< Accelerator > accel)
Assume ownership of an accelerator object.
Background thread which services various requests.
void stop()
Instruct the service thread to stop running.
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::unique_ptr< Impl > impl
void addPoll(HWModule &module)
Poll this module.
Top level accelerator class.
Definition Accelerator.h:59
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:31
Represents either the top level or an instance of a hardware module.
Definition Design.h:47
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:60
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition Logging.h:71
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition Logging.h:79
A logical chunk of data representing serialized data.
Definition Common.h:103
A ChannelPort which reads data from the accelerator.
Definition Ports.h:103
std::map< std::string, BackendCreate > backendRegistry
static std::map< std::string, BackendCreate > & get()
static Service * createService(AcceleratorConnection *acc, Service::Type svcType, AppIDPath id, std::string implName, ServiceImplDetails details, HWClientDetails clients)
Create a service instance from the given details.
Definition Services.cpp:293
Parent class of all APIs modeled as 'services'.
Definition Services.h:45
const std::type_info & Type
Definition Services.h:47
void registerBackend(const std::string &name, BackendCreate create)
std::function< std::unique_ptr< AcceleratorConnection >(Context &, std::string)> BackendCreate
Backends can register themselves to be connected via a connection string.
std::unique_ptr< AcceleratorConnection > connect(Context &ctxt, const std::string &backend, const std::string &connection)
Definition esi.py:1
static std::filesystem::path getExePath()
Get the path to the currently running executable.
std::map< std::string, std::any > ServiceImplDetails
Definition Common.h:98
static void loadBackend(Context &ctxt, std::string backend)
Load a backend plugin dynamically.
static std::filesystem::path getLibPath()
Get the path to the currently running shared library.
std::vector< HWClientDetail > HWClientDetails
Definition Common.h:97
std::map< ReadChannelPort *, std::pair< std::function< void(ReadChannelPort *, MessageData)>, std::future< MessageData > > > listeners
void addTask(std::function< void(void)> task)
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::vector< std::function< void(void)> > taskList
Tasks which should be called on every loop iteration.