// tensor_predictors/mvbernoulli/inst/include/threadPool.h
//
// 140 lines
// 4.4 KiB
// C
// Raw Permalink Normal View History

#ifndef THREADPOOL_INCLUDE_GUARD_H
#define THREADPOOL_INCLUDE_GUARD_H
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
/**
 * Fixed-size thread pool executing queued `void()` jobs on worker threads.
 *
 * Jobs are taken from the front of a queue by whichever worker becomes idle
 * first. All shared state (`_shutdown`, `_running`, `_jobs`) is read and
 * written only while holding `_mtx`.
 *
 * Jobs must not throw: an exception leaving a job leaves the worker's thread
 * function, which terminates the program (`std::thread` semantics).
 *
 * The destructor lets the workers finish every job still in the queue before
 * joining them; call `clear()` first to discard pending jobs instead.
 */
class ThreadPool {
public:
    /**
     * Starts the worker threads.
     *
     * @param n number of worker threads. A value of 0 (which
     *  `std::thread::hardware_concurrency()` returns when the hardware
     *  concurrency is unknown) is replaced by 1, since a pool without
     *  workers would never run a job and `wait()` would block forever.
     */
    ThreadPool(std::size_t n = std::thread::hardware_concurrency())
        : _shutdown{false}, _running{0}
    {
        if (n == 0) {
            n = 1;
        }
        // Reserve space for all workers up front
        _workers.reserve(n);
        try {
            for (std::size_t i = 0; i < n; ++i) {
                _workers.emplace_back([this]() { worker_loop(); });
            }
        } catch (...) {
            // Starting a thread failed: the destructor will not run for a
            // partially constructed object, so stop and join the workers
            // that were already launched before propagating the error
            // (destroying a joinable std::thread terminates the program).
            shutdown_and_join();
            throw;
        }
    }

    /** Signals shutdown, lets the workers drain the queue and joins them. */
    ~ThreadPool() { shutdown_and_join(); }

    // The pool owns threads that refer to `this`; it can not be copied.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    /**
     * Adds a job to the job queue.
     *
     * `job` and `args` are copied into the queued task, which later calls
     * `job(args...)` on one of the worker threads.
     *
     * @param job callable to execute.
     * @param args arguments passed to `job`.
     */
    template <typename Fun, typename ...Args>
    void push(Fun&& job, Args&&... args) {
        // enqueue the new task while holding the lock
        {
            std::lock_guard<std::mutex> lock(_mtx);
            _jobs.push([job, args...]() { job(args...); });
        }
        // wake one idle worker, there is work to be done
        _pager.notify_one();
    }

    /**
     * Blocks until the job queue is empty and no job is running.
     *
     * Must not be called from inside a job: that job counts as running, so
     * the condition could never become true.
     */
    void wait() {
        std::unique_lock<std::mutex> lock(_mtx);
        // The predicate overload re-checks the condition after every wake-up
        // (including spurious ones) and only returns once it holds.
        _callback.wait(lock, [this]() {
            return _jobs.empty() && _running == 0;
        });
    }

    /** Discards all queued jobs; jobs that are already running continue. */
    void clear() {
        {
            std::lock_guard<std::mutex> lock(_mtx);
            // swap the job queue with an empty queue
            std::queue<std::function<void()>>().swap(_jobs);
        }
        // The queue became empty without any worker finishing a job, so
        // nobody else would wake a thread blocked in `wait()`.
        _callback.notify_all();
    }

    /** @return number of jobs currently being executed. */
    std::size_t running_jobs() const {
        std::lock_guard<std::mutex> lock(_mtx);
        return _running;
    }

    /** @return number of queued jobs (waiting for execution). */
    std::size_t queued_jobs() const {
        std::lock_guard<std::mutex> lock(_mtx);
        return _jobs.size();
    }

    /** @return number of worker threads. */
    std::size_t workers() const { return _workers.size(); }

private:
    /** Body of every worker thread: fetch and run jobs until shutdown. */
    void worker_loop() {
        std::unique_lock<std::mutex> lock(_mtx);
        for (;;) {
            // sleep until either shutdown is requested or work is available
            _pager.wait(lock, [this]() {
                return _shutdown || !_jobs.empty();
            });
            // An empty queue after the wait implies shutdown; remaining
            // jobs are still processed before the worker terminates.
            if (_jobs.empty()) {
                return; // releases the lock
            }
            // take the next job out of the queue (moved, not copied)
            std::function<void()> job = std::move(_jobs.front());
            _jobs.pop();
            // count the job as running before releasing the lock, so the
            // number of queued plus running jobs never appears too small
            ++_running;
            // run the job without the lock so other workers can proceed
            lock.unlock();
            job();
            // destroy the job's captured state while the lock is not held
            job = nullptr;
            // Update the counter under the lock: `wait()` evaluates its
            // predicate under the same lock, so the change can not slip in
            // between its check and its going to sleep (lost wake-up).
            lock.lock();
            --_running;
            // report a finished job to every thread blocked in `wait()`
            _callback.notify_all();
        }
    }

    /** Requests shutdown and joins all started worker threads. */
    void shutdown_and_join() {
        {
            // set the flag under the lock so that no worker can miss it
            // between checking the flag and going to sleep
            std::lock_guard<std::mutex> lock(_mtx);
            _shutdown = true;
        }
        // wake all idle workers so they notice the shutdown request
        _pager.notify_all();
        // finally join the worker threads into the calling thread
        for (auto& thr : _workers) {
            if (thr.joinable()) {
                thr.join();
            }
        }
    }

    bool _shutdown;                     // set once, tells workers to stop
    std::size_t _running;               // number of running jobs
    std::vector<std::thread> _workers;  // worker threads
    std::queue<std::function<void()>> _jobs; // jobs waiting for execution
    std::condition_variable _pager;     // for waking idle workers
    std::condition_variable _callback;  // for workers reporting a done job
    mutable std::mutex _mtx;            // guards all shared state above
};
#endif /* THREADPOOL_INCLUDE_GUARD_H */