CIRCT  20.0.0git
Accelerator.cpp
Go to the documentation of this file.
1 //===- Accelerator.cpp - ESI accelerator system API -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DO NOT EDIT!
10 // This file is distributed as part of an ESI package. The source for this file
11 // should always be modified within CIRCT (lib/dialect/ESI/runtime/cpp/).
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "esi/Accelerator.h"

#include <atomic>
#include <cassert>
#include <cctype>
#include <filesystem>
#include <iostream>
#include <map>
#include <stdexcept>

#ifdef __linux__
#include <dlfcn.h>
#include <linux/limits.h>
#include <unistd.h>
#elif _WIN32
#include <windows.h>
#endif
31 
32 using namespace esi;
33 using namespace esi::services;
34 
35 namespace esi {
37  : ctxt(ctxt), serviceThread(nullptr) {}
39 
41  if (!serviceThread)
42  serviceThread = std::make_unique<AcceleratorServiceThread>();
43  return serviceThread.get();
44 }
45 
47  AppIDPath id,
48  std::string implName,
49  ServiceImplDetails details,
50  HWClientDetails clients) {
51  std::unique_ptr<Service> &cacheEntry = serviceCache[make_tuple(&svcType, id)];
52  if (cacheEntry == nullptr) {
53  Service *svc = createService(svcType, id, implName, details, clients);
54  if (!svc)
55  svc = ServiceRegistry::createService(this, svcType, id, implName, details,
56  clients);
57  if (!svc)
58  return nullptr;
59  cacheEntry = std::unique_ptr<Service>(svc);
60  }
61  return cacheEntry.get();
62 }
63 
65 AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
66  Accelerator *ret = acc.get();
67  ownedAccelerators.push_back(std::move(acc));
68  return ret;
69 }
70 
/// Get the path to the currently running executable.
///
/// \throws std::runtime_error if the path cannot be queried, or if it does not
///         fit in the platform's fixed-size path buffer (a truncated path
///         would silently point somewhere else).
static std::filesystem::path getExePath() {
#ifdef __linux__
  char result[PATH_MAX];
  // readlink() does not NUL-terminate, and it returns the full buffer size
  // when the link target was truncated, so a full buffer cannot be trusted.
  ssize_t count = readlink("/proc/self/exe", result, PATH_MAX);
  if (count == -1)
    throw std::runtime_error("Could not get executable path");
  if (count >= PATH_MAX)
    throw std::runtime_error("Executable path too long");
  return std::filesystem::path(std::string(result, count));
#elif _WIN32
  char buffer[MAX_PATH];
  // GetModuleFileNameA() truncates and returns the buffer size when the path
  // does not fit.
  DWORD length = GetModuleFileNameA(NULL, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get executable path");
  if (length >= MAX_PATH)
    throw std::runtime_error("Executable path too long");
  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
89 
/// Get the path to the currently running shared library, i.e. the module
/// that contains this function.
///
/// \returns the module path, or an empty path if the containing module cannot
///          be identified.
/// \throws std::runtime_error (Windows only) if the module's file name cannot
///         be retrieved or does not fit in MAX_PATH.
static std::filesystem::path getLibPath() {
#ifdef __linux__
  Dl_info dl_info;
  // dladdr() returns 0 on failure and leaves dl_info unspecified, and
  // dli_fname may be null, so never build a std::string from it unchecked.
  // Mirror the Windows branch by returning an empty path on lookup failure.
  if (dladdr(reinterpret_cast<void *>(&getLibPath), &dl_info) == 0 ||
      dl_info.dli_fname == nullptr)
    return std::filesystem::path();
  return std::filesystem::path(std::string(dl_info.dli_fname));
#elif _WIN32
  HMODULE hModule = NULL;
  if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                              GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                          reinterpret_cast<LPCSTR>(&getLibPath), &hModule)) {
    // Could not identify the containing module.
    return std::filesystem::path();
  }

  char buffer[MAX_PATH];
  DWORD length = GetModuleFileNameA(hModule, buffer, MAX_PATH);
  if (length == 0)
    throw std::runtime_error("Could not get library path");
  // A return value equal to the buffer size means the path was truncated.
  if (length >= MAX_PATH)
    throw std::runtime_error("Library path too long");

  return std::filesystem::path(std::string(buffer, length));
#else
#error "Unsupported platform"
#endif
}
115 
116 /// Load a backend plugin dynamically. Plugins are expected to be named
117 /// lib<BackendName>Backend.so and located in one of 1) CWD, 2) in the same
118 /// directory as the application, or 3) in the same directory as this library.
119 static void loadBackend(Context &ctxt, std::string backend) {
120  Logger &logger = ctxt.getLogger();
121  backend[0] = toupper(backend[0]);
122 
123  // Get the file name we are looking for.
124 #ifdef __linux__
125  std::string backendFileName = "lib" + backend + "Backend.so";
126 #elif _WIN32
127  std::string backendFileName = backend + "Backend.dll";
128 #else
129 #eror "Unsupported platform"
130 #endif
131 
132  // Look for library using the C++ std API.
133  // TODO: once the runtime has a logging framework, log the paths we are
134  // trying.
135 
136  // First, try the current directory.
137  std::filesystem::path backendPath = backendFileName;
138  std::string backendPathStr;
139  logger.debug("CONNECT",
140  "trying to load backend plugin: " + backendPath.string());
141  if (!std::filesystem::exists(backendPath)) {
142  // Next, try the directory of the executable.
143  backendPath = getExePath().parent_path().append(backendFileName);
144  logger.debug("CONNECT",
145  "trying to load backend plugin: " + backendPath.string());
146  if (!std::filesystem::exists(backendPath)) {
147  // Finally, try the directory of the library.
148  backendPath = getLibPath().parent_path().append(backendFileName);
149  logger.debug("CONNECT",
150  "trying to load backend plugin: " + backendPath.string());
151  if (!std::filesystem::exists(backendPath)) {
152  // If all else fails, just try the name.
153  backendPathStr = backendFileName;
154  logger.debug("CONNECT",
155  "trying to load backend plugin: " + backendPathStr);
156  }
157  }
158  }
159  // If the path was found, convert it to a string.
160  if (backendPathStr.empty())
161  backendPathStr = backendPath.string();
162  else
163  // Otherwise, signal that the path wasn't found by clearing the path and
164  // just use the name. (This is only used on Windows to add the same
165  // directory as the backend DLL to the DLL search path.)
166  backendPath.clear();
167 
168  // Attempt to load it.
169 #ifdef __linux__
170  void *handle = dlopen(backendPathStr.c_str(), RTLD_NOW | RTLD_GLOBAL);
171  if (!handle) {
172  std::string error(dlerror());
173  logger.error("CONNECT",
174  "while attempting to load backend plugin: " + error);
175  throw std::runtime_error("While attempting to load backend plugin: " +
176  error);
177  }
178 #elif _WIN32
179  // Set the DLL directory to the same directory as the backend DLL in case it
180  // has transitive dependencies.
181  if (backendPath != std::filesystem::path()) {
182  std::filesystem::path backendPathParent = backendPath.parent_path();
183  if (SetDllDirectoryA(backendPathParent.string().c_str()) == 0)
184  throw std::runtime_error("While setting DLL directory: " +
185  std::to_string(GetLastError()));
186  }
187 
188  // Load the backend plugin.
189  HMODULE handle = LoadLibraryA(backendPathStr.c_str());
190  if (!handle) {
191  DWORD error = GetLastError();
192  if (error == ERROR_MOD_NOT_FOUND) {
193  logger.error("CONNECT", "while attempting to load backend plugin: " +
194  backendPathStr + " not found");
195  throw std::runtime_error("While attempting to load backend plugin: " +
196  backendPathStr + " not found");
197  }
198  logger.error("CONNECT", "while attempting to load backend plugin: " +
199  std::to_string(error));
200  throw std::runtime_error("While attempting to load backend plugin: " +
201  std::to_string(error));
202  }
203 #else
204 #eror "Unsupported platform"
205 #endif
206  logger.info("CONNECT", "loaded backend plugin: " + backendPathStr);
207 }
208 
209 namespace registry {
210 namespace internal {
211 
213 public:
214  static std::map<std::string, BackendCreate> &get() {
215  static BackendRegistry instance;
216  return instance.backendRegistry;
217  }
218 
219 private:
220  std::map<std::string, BackendCreate> backendRegistry;
221 };
222 
223 void registerBackend(const std::string &name, BackendCreate create) {
224  auto &registry = BackendRegistry::get();
225  if (registry.count(name))
226  throw std::runtime_error("Backend already exists in registry");
227  registry[name] = create;
228 }
229 } // namespace internal
230 
231 std::unique_ptr<AcceleratorConnection> connect(Context &ctxt,
232  const std::string &backend,
233  const std::string &connection) {
234  auto &registry = internal::BackendRegistry::get();
235  auto f = registry.find(backend);
236  if (f == registry.end()) {
237  // If it's not already found in the registry, try to load it dynamically.
238  loadBackend(ctxt, backend);
239  f = registry.find(backend);
240  if (f == registry.end()) {
241  ctxt.getLogger().error("CONNECT", "backend '" + backend + "' not found");
242  throw std::runtime_error("Backend '" + backend + "' not found");
243  }
244  }
245  ctxt.getLogger().info("CONNECT", "connecting to backend " + backend +
246  " via '" + connection + "'");
247  return f->second(ctxt, connection);
248 }
249 
250 } // namespace registry
251 
253  Impl() {}
254  void start() { me = std::thread(&Impl::loop, this); }
255  void stop() {
256  shutdown = true;
257  me.join();
258  }
259  /// When there's data on any of the listenPorts, call the callback. This
260  /// method can be called from any thread.
261  void
262  addListener(std::initializer_list<ReadChannelPort *> listenPorts,
263  std::function<void(ReadChannelPort *, MessageData)> callback);
264 
265  void addTask(std::function<void(void)> task) {
266  std::lock_guard<std::mutex> g(m);
267  taskList.push_back(task);
268  }
269 
270 private:
271  void loop();
272  volatile bool shutdown = false;
273  std::thread me;
274 
275  // Protect the shared data structures.
276  std::mutex m;
277 
278  // Map of read ports to callbacks.
279  std::map<ReadChannelPort *,
280  std::pair<std::function<void(ReadChannelPort *, MessageData)>,
281  std::future<MessageData>>>
283 
284  /// Tasks which should be called on every loop iteration.
285  std::vector<std::function<void(void)>> taskList;
286 };
287 
289  // These two variables should logically be in the loop, but this avoids
290  // reconstructing them on each iteration.
291  std::vector<std::tuple<ReadChannelPort *,
292  std::function<void(ReadChannelPort *, MessageData)>,
293  MessageData>>
294  portUnlockWorkList;
295  std::vector<std::function<void(void)>> taskListCopy;
296  MessageData data;
297 
298  while (!shutdown) {
299  // Ideally we'd have some wake notification here, but this sufficies for
300  // now.
301  // TODO: investigate better ways to do this.
302  std::this_thread::sleep_for(std::chrono::microseconds(100));
303 
304  // Check and gather data from all the read ports we are monitoring. Put the
305  // callbacks to be called later so we can release the lock.
306  {
307  std::lock_guard<std::mutex> g(m);
308  for (auto &[channel, cbfPair] : listeners) {
309  assert(channel && "Null channel in listener list");
310  std::future<MessageData> &f = cbfPair.second;
311  if (f.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
312  portUnlockWorkList.emplace_back(channel, cbfPair.first, f.get());
313  f = channel->readAsync();
314  }
315  }
316  }
317 
318  // Call the callbacks outside the lock.
319  for (auto [channel, cb, data] : portUnlockWorkList)
320  cb(channel, std::move(data));
321 
322  // Clear the worklist for the next iteration.
323  portUnlockWorkList.clear();
324 
325  // Call any tasks that have been added. Copy it first so we can release the
326  // lock ASAP.
327  {
328  std::lock_guard<std::mutex> g(m);
329  taskListCopy = taskList;
330  }
331  for (auto &task : taskListCopy)
332  task();
333  }
334 }
335 
337  std::initializer_list<ReadChannelPort *> listenPorts,
338  std::function<void(ReadChannelPort *, MessageData)> callback) {
339  std::lock_guard<std::mutex> g(m);
340  for (auto port : listenPorts) {
341  if (listeners.count(port))
342  throw std::runtime_error("Port already has a listener");
343  listeners[port] = std::make_pair(callback, port->readAsync());
344  }
345 }
346 
347 } // namespace esi
348 
350  : impl(std::make_unique<Impl>()) {
351  impl->start();
352 }
354 
356  if (impl) {
357  impl->stop();
358  impl.reset();
359  }
360 }
361 
362 // When there's data on any of the listenPorts, call the callback. This is
363 // kinda silly now that we have callback port support, especially given the
364 // polling loop. Keep the functionality for now.
366  std::initializer_list<ReadChannelPort *> listenPorts,
367  std::function<void(ReadChannelPort *, MessageData)> callback) {
368  assert(impl && "Service thread not running");
369  impl->addListener(listenPorts, callback);
370 }
371 
373  assert(impl && "Service thread not running");
374  impl->addTask([&module]() { module.poll(); });
375 }
376 
378  if (serviceThread) {
379  serviceThread->stop();
380  serviceThread.reset();
381  }
382 }
assert(baseType &&"element must be base type")
virtual void disconnect()
Disconnect from the accelerator cleanly.
ServiceClass * getService(AppIDPath id={}, std::string implName={}, ServiceImplDetails details={}, HWClientDetails clients={})
Get a typed reference to a particular service type.
Definition: Accelerator.h:112
std::map< ServiceCacheKey, std::unique_ptr< Service > > serviceCache
Definition: Accelerator.h:145
std::unique_ptr< AcceleratorServiceThread > serviceThread
Definition: Accelerator.h:147
virtual Service * createService(Service::Type service, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients)=0
Called by getServiceImpl exclusively.
std::vector< std::unique_ptr< Accelerator > > ownedAccelerators
List of accelerator objects owned by this connection.
Definition: Accelerator.h:151
AcceleratorServiceThread * getServiceThread()
Return a pointer to the accelerator 'service' thread (or threads).
Definition: Accelerator.cpp:40
AcceleratorConnection(Context &ctxt)
Definition: Accelerator.cpp:36
Accelerator * takeOwnership(std::unique_ptr< Accelerator > accel)
Assume ownership of an accelerator object.
Definition: Accelerator.cpp:65
Background thread which services various requests.
Definition: Accelerator.h:186
void stop()
Instruct the service thread to stop running.
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::unique_ptr< Impl > impl
Definition: Accelerator.h:204
void addPoll(HWModule &module)
Poll this module.
Top level accelerator class.
Definition: Accelerator.h:59
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition: Context.h:31
Represents either the top level or an instance of a hardware module.
Definition: Design.h:47
bool poll()
Master poll method.
Definition: Design.cpp:50
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition: Logging.h:60
virtual void info(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an informational message.
Definition: Logging.h:71
void debug(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report a debug message.
Definition: Logging.h:79
A logical chunk of data representing serialized data.
Definition: Common.h:103
A ChannelPort which reads data from the accelerator.
Definition: Ports.h:103
static std::map< std::string, BackendCreate > & get()
std::map< std::string, BackendCreate > backendRegistry
Parent class of all APIs modeled as 'services'.
Definition: Services.h:45
const std::type_info & Type
Definition: Services.h:47
void registerBackend(const std::string &name, BackendCreate create)
std::function< std::unique_ptr< AcceleratorConnection >(Context &, std::string)> BackendCreate
Backends can register themselves to be connected via a connection string.
Definition: Accelerator.h:166
std::unique_ptr< AcceleratorConnection > connect(Context &ctxt, const std::string &backend, const std::string &connection)
Definition: esi.py:1
static std::filesystem::path getExePath()
Get the path to the currently running executable.
Definition: Accelerator.cpp:72
std::map< std::string, std::any > ServiceImplDetails
Definition: Common.h:98
static void loadBackend(Context &ctxt, std::string backend)
Load a backend plugin dynamically.
static std::filesystem::path getLibPath()
Get the path to the currently running shared library.
Definition: Accelerator.cpp:91
std::vector< HWClientDetail > HWClientDetails
Definition: Common.h:97
std::map< ReadChannelPort *, std::pair< std::function< void(ReadChannelPort *, MessageData)>, std::future< MessageData > > > listeners
void addTask(std::function< void(void)> task)
void addListener(std::initializer_list< ReadChannelPort * > listenPorts, std::function< void(ReadChannelPort *, MessageData)> callback)
When there's data on any of the listenPorts, call the callback.
std::vector< std::function< void(void)> > taskList
Tasks which should be called on every loop iteration.