I have developed a threadpool using standard C++11 features and am looking for feedback.
Right now I think the implementation is pretty solid, but since I only recently learned about multithreading, I'm not sure whether I have a solid grasp of it.
Code
Threadpool.h
#pragma once
#include <cassert>
#include <cstddef>
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <list>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
namespace threadpool {
template <class Result> class TaskHandler;
/**
 * @brief Pool of worker threads executing queued callables
 *
 * All public member functions may be called concurrently from any thread;
 * the shared state they touch is guarded by mTasksMutex. The destructor must
 * not run concurrently with any other member function.
 */
class Threadpool {
public:
    // Defines how tasks queued with DEFAULT priority are ordered
    enum TaskOrdering {
        FIFO, // First In First Out
        LIFO  // Last In First Out
    };

    // Priority with which a task will be queued
    enum TaskPriority {
        MAX,     // The task will be the next one to execute
        DEFAULT, // Placed according to the pool's TaskOrdering (FIFO or LIFO)
        MIN      // The task will be put past all stored tasks
    };

    /**
     * @brief Create a threadpool holding a certain number of threads
     *
     * @param nWorkers number of worker threads to start
     * @param ordering default ordering for DEFAULT-priority tasks
     *                 FIFO = First In First Out
     *                 LIFO = Last In First Out
     * @throws std::system_error if a worker thread cannot be started; the
     *         workers started so far are stopped and joined before rethrowing
     */
    explicit Threadpool(unsigned int nWorkers, TaskOrdering ordering = FIFO) :
        mOrdering(ordering),
        mNWorkers(nWorkers)
    {
        try {
            for (unsigned int i = 0; i < nWorkers; ++i) {
                // Workers are registered under the lock so a worker always
                // finds its own std::thread object in mWorkers.
                std::lock_guard<std::mutex> lock(mTasksMutex);
                mWorkers.emplace_back(&Threadpool::workerFunction, this);
            }
        }
        catch (...) {
            // The destructor does not run when a constructor throws, and
            // destroying a joinable std::thread calls std::terminate, so the
            // workers started so far are stopped here.
            shutdown();
            throw;
        }
    }

    /**
     * @brief Join all worker threads, leaving all remaining tasks undone
     *
     * Tasks still in the queue are discarded without being run.
     */
    ~Threadpool() {
        shutdown();
    }

    Threadpool(const Threadpool&) = delete;
    Threadpool& operator=(const Threadpool&) = delete;

    /**
     * @brief Queue a task and return a handle to it
     *
     * Exceptions thrown by fn are captured by the std::packaged_task and
     * delivered through the handle's get().
     * Arguments are stored by std::bind (copied or moved) and passed to fn as
     * lvalues, so fn cannot take its parameters by rvalue reference.
     *
     * @param fn Callable
     * @param args Arguments for fn
     * @tparam Priority
     * @return Handle to the task
     */
    template <TaskPriority Priority = DEFAULT, class FN, class... Args>
    auto queueAndHandleTask(FN&& fn, Args&&... args)
        -> TaskHandler<typename std::result_of<FN(Args...)>::type>
    {
        using return_type = typename std::result_of<FN(Args...)>::type;
        // std::function needs a copyable callable but std::packaged_task is
        // move-only, so the queued lambda holds it through a shared_ptr.
        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<FN>(fn), std::forward<Args>(args)...)
        );
        queueTask<Priority>([task]() { (*task)(); });
        return TaskHandler<return_type>(task);
    }

    /**
     * @brief Queue a task
     *
     * @note An exception escaping the task leaves the worker's thread
     *       function and therefore terminates the program; use
     *       queueAndHandleTask to have exceptions reported to the caller.
     *
     * @param fn Callable
     * @param args Arguments for fn (stored by std::bind)
     * @tparam Priority
     */
    template <TaskPriority Priority = DEFAULT, class FN, class... Args>
    void queueTask(FN&& fn, Args&&... args) {
        {
            std::lock_guard<std::mutex> lock(mTasksMutex);
            auto insertPosition = getInsertPosition(Priority, mOrdering);
            mTasks.emplace(insertPosition, std::bind(std::forward<FN>(fn), std::forward<Args>(args)...));
        }
        mTasksCond.notify_one();
    }

    /**
     * @brief Changes the default task ordering for the threadpool
     * @param newOrdering ordering used for tasks queued from now on
     */
    void setDefaultOrdering(TaskOrdering newOrdering) {
        // mOrdering is read by queueTask under the same mutex
        std::lock_guard<std::mutex> lock(mTasksMutex);
        mOrdering = newOrdering;
    }

    /**
     * @brief Change the number of threads managed by the threadpool
     *
     * Growing first cancels removals that are still pending and then starts
     * new threads. Shrinking asks workers to leave; a busy worker leaves only
     * after finishing its current task. Workers that have already left are
     * joined here.
     *
     * @param newWorkers new thread number
     * @throws std::system_error if a new thread cannot be started; the pool
     *         keeps the threads it managed to start
     */
    void resize(unsigned int newWorkers) {
        std::list<std::thread> retired;
        bool workersMustLeave = false;
        {
            std::lock_guard<std::mutex> lock(mTasksMutex);
            if (newWorkers > mNWorkers) {
                // Workers still waiting to leave can simply stay.
                const unsigned int reused = std::min(newWorkers - mNWorkers, mThreadsToKill);
                mThreadsToKill -= reused;
                mNWorkers += reused;
                // Counted one at a time so mNWorkers stays exact if a
                // thread cannot be started.
                for (; mNWorkers < newWorkers; ++mNWorkers) {
                    mWorkers.emplace_back(&Threadpool::workerFunction, this);
                }
            }
            else if (newWorkers < mNWorkers) {
                // Accumulated rather than assigned: earlier removals may still be pending
                mThreadsToKill += mNWorkers - newWorkers;
                mNWorkers = newWorkers;
                workersMustLeave = true;
            }
            // Taken last so that, if starting a thread threw above, retired
            // threads stay in mRetiredWorkers instead of being destroyed joinable.
            retired.swap(mRetiredWorkers);
        }
        if (workersMustLeave) {
            // Every idle worker re-checks under the lock whether it has to leave
            mTasksCond.notify_all();
        }
        for (auto& worker : retired) {
            worker.join();
        }
    }

    /**
     * @brief Blocks the calling thread until no task is queued or running
     *
     * Must not be called from inside a task (it would wait for itself), and
     * blocks indefinitely if tasks are queued while the pool has no workers.
     */
    void waitForTasks() {
        std::unique_lock<std::mutex> lock(mTasksMutex);
        mWaitCond.wait(lock, [this] { return mTasks.empty() && mActiveTasks == 0; });
    }

private:
    enum WorkerAction {
        WAIT,
        WORK,
        EXIT,
        SUICIDE
    };

    // Decides what a worker does next. Caller must hold mTasksMutex.
    WorkerAction getNextAction() const {
        if (mPoolDestroyed) return EXIT;
        if (mThreadsToKill > 0) return SUICIDE;
        if (!mTasks.empty()) return WORK;
        return WAIT;
    }

    // Where a task with the given priority goes in mTasks. Caller must hold mTasksMutex.
    std::list<std::function<void()>>::const_iterator
    getInsertPosition(TaskPriority priority, TaskOrdering ordering) {
        if (priority == MAX) return mTasks.cbegin();
        if (priority == MIN) return mTasks.cend();
        return (ordering == LIFO) ? mTasks.cbegin() : mTasks.cend();
    }

    // Tells every worker to exit and joins them all; queued tasks stay undone.
    void shutdown() {
        {
            std::lock_guard<std::mutex> lock(mTasksMutex);
            mPoolDestroyed = true;
            // cancel possible thread suicides
            mThreadsToKill = 0;
        }
        mTasksCond.notify_all();
        // Once mPoolDestroyed is set, workers only exit and never modify the
        // worker lists again, so they can be walked without the lock.
        for (auto& worker : mWorkers) {
            worker.join();
        }
        for (auto& worker : mRetiredWorkers) {
            worker.join();
        }
    }

    // Moves the calling worker's std::thread from mWorkers to mRetiredWorkers
    // so it is joined later instead of detached. Caller must hold mTasksMutex.
    void retireCurrentWorker() {
        assert(mThreadsToKill > 0);
        --mThreadsToKill;
        const std::thread::id self = std::this_thread::get_id();
        auto threadIt = std::find_if(mWorkers.begin(), mWorkers.end(),
            [self](const std::thread& worker) { return worker.get_id() == self; });
        // Workers are inserted into mWorkers under mTasksMutex before they can
        // acquire it themselves, so the calling worker is always present.
        assert(threadIt != mWorkers.end());
        // splice relinks the list node: no std::thread is copied or destroyed
        mRetiredWorkers.splice(mRetiredWorkers.end(), mWorkers, threadIt);
    }

    // Main loop of every worker thread.
    void workerFunction() {
        while (true) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(mTasksMutex);
                WorkerAction nextAction = WAIT;
                // The predicate is checked before blocking, so pending work is
                // picked up without waiting for a notification.
                mTasksCond.wait(lock, [this, &nextAction] {
                    return (nextAction = getNextAction()) != WAIT;
                });
                if (nextAction == EXIT) {
                    return;
                }
                if (nextAction == SUICIDE) {
                    retireCurrentWorker();
                    return;
                }
                assert(nextAction == WORK);
                task = std::move(mTasks.front());
                mTasks.pop_front();
                // Counted while still locked so waitForTasks never sees an
                // empty queue while this task is about to run.
                ++mActiveTasks;
            }
            // The task itself is not performed under the lock
            task();
            {
                std::lock_guard<std::mutex> lock(mTasksMutex);
                --mActiveTasks;
                if (mTasks.empty() && mActiveTasks == 0) {
                    // notify_all: several threads may be blocked in waitForTasks()
                    mWaitCond.notify_all();
                }
            }
        }
    }

    // All members below are shared between the calling threads and the
    // workers and are accessed only with mTasksMutex held. The exception is
    // shutdown(), which walks the worker lists after every worker has been
    // told to exit.

    // std::list keeps std::thread objects in place and allows splicing them
    // between the two lists when a worker leaves the pool
    std::list<std::thread> mWorkers;
    // Workers that left through resize(); joined by resize() or the destructor
    std::list<std::thread> mRetiredWorkers;
    std::list<std::function<void()>> mTasks;
    // Wakes workers when there is a task, a shrink request or shutdown
    std::condition_variable mTasksCond;
    std::mutex mTasksMutex;
    // Wakes waitForTasks() callers when no task is queued or running
    std::condition_variable mWaitCond;
    TaskOrdering mOrdering;
    // Number of workers that still have to leave because of resize()
    unsigned int mThreadsToKill = 0;
    // Requested number of workers
    unsigned int mNWorkers;
    // Number of tasks currently being executed by workers
    std::size_t mActiveTasks = 0;
    bool mPoolDestroyed = false;
};
/**
 * @brief Handles a task's state and gives access to the task's result
 *
 * Only the std::future of the task is stored. The std::packaged_task is owned
 * solely by the queued task, so if the pool discards the task without running
 * it, the packaged_task is destroyed and wait() returns while get() throws
 * std::future_error (broken_promise), instead of both blocking forever.
 *
 * @tparam Result type of the task result
 */
template <class Result>
class TaskHandler {
    // Users are not allowed to construct new task handlers
    template <Threadpool::TaskPriority Priority, class FN, class... Args>
    friend TaskHandler<typename std::result_of<FN(Args...)>::type>
    Threadpool::queueAndHandleTask(FN&&, Args&&...);
public:
    TaskHandler(TaskHandler&& other) noexcept :
        mResult(std::move(other.mResult))
    {}
    TaskHandler& operator=(TaskHandler&& other) noexcept {
        mResult = std::move(other.mResult);
        return *this;
    }
    /**
     * @brief Returns whether the task has completed
     * @return false if the result is not ready yet, or if this handler no
     *         longer refers to a result (moved from, or get() already called)
     */
    bool resultReady() const {
        // wait_for on a future without shared state is undefined behaviour
        return mResult.valid() &&
               mResult.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
    }
    /**
     * @brief Get the task result
     * @details Blocks the calling thread until the result is ready.
     * Users are responsible for calling this method only once.
     * Rethrows any exception thrown by the task.
     * @return The value the task returned
     */
    Result get() { return mResult.get(); }
    /**
     * @brief Blocks the calling thread until the result is ready
     */
    void wait() { mResult.wait(); }
private:
    explicit TaskHandler(const std::shared_ptr<std::packaged_task<Result()>>& task) :
        mResult(task->get_future())
    {}
    TaskHandler(const TaskHandler&) = delete;
    std::future<Result> mResult;
};
}
Examples - main.cpp
// Threadpool.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include "Threadpool.h"
#include <iostream>
#include <memory>
using namespace threadpool;
// Small callable used by the examples: it stores an int, returns it when
// called, and offers a slow member function that increments it.
struct SampleFunctor {
    int data;

    SampleFunctor(int a) : data{a} {}

    // Returns the stored value.
    int operator()() { return data; }

    // Simulates an expensive operation: sleeps for five seconds, then
    // increments data by one.
    void expensiveMethod() {
        const auto delay = std::chrono::seconds(5);
        std::this_thread::sleep_for(delay);
        data += 1;
    }
};
// Factory used by the "free function" example: returns a heap-allocated
// SampleFunctor holding `data`.
std::unique_ptr<SampleFunctor> makeFunctor(int data) {
    auto functor = std::make_unique<SampleFunctor>(data);
    return functor;
}
// Example singleton whose accessor is deliberately slow, used to show that a
// static member function can be queued on the pool.
class SomeManager {
public:
    // Returns the single instance. The sleep simulates an expensive call and
    // happens on every invocation, not only the first one.
    static SomeManager& getInstance()
    {
        // Function-local static: its initialization is thread-safe since C++11
        static SomeManager instance;
        std::this_thread::sleep_for(std::chrono::seconds(3));
        return instance;
    }
    // A singleton must not be duplicated: without these, the implicit copy
    // constructor would allow `SomeManager copy = SomeManager::getInstance();`.
    SomeManager(const SomeManager&) = delete;
    SomeManager& operator=(const SomeManager&) = delete;
private:
    SomeManager() {}
};
int main()
{
Threadpool threadpool(4);
{
// Queue free functions
auto handler = threadpool.queueAndHandleTask(makeFunctor, 4);
auto result = handler.get();
std::cout << "Result is " << result->data << "\n";
}
// Queue a lambda
threadpool.queueTask([]() {
int counter = 0;
while (counter < 1000) ++counter;
});
{
// Queue functors
auto handler = threadpool.queueAndHandleTask(SampleFunctor{ 20 });
auto result = handler.get();
std::cout << "Functor example: " << result << "\n";
}
{
// Queue member methods
// You can handle void tasks too!
SampleFunctor f{ 20 };
auto handler = threadpool.queueAndHandleTask(&SampleFunctor::expensiveMethod, &f);
// note that this blocks the calling thread
handler.wait();
// you can use handler.get() too (if you dont assign it to anything)
std::cout << "Expensive method done! Data should be 21 -> " << f.data << "\n";
}
{
// Queue static class methods
auto handler = threadpool.queueAndHandleTask(&SomeManager::getInstance);
SomeManager& managerSingleton = handler.get();
std::cout << "Our singleton is located at " << &managerSingleton << "\n";
}
return 0;
}
Extended usage
Additional features:
- Adding/removing threads at runtime
- Getting task results and waiting for them to be ready
- Task ordering and priority
You can find the complete documentation in the GitHub repository.
Questions
My main concerns are:
- Is the interface easy to use?
- Are there edge cases the threadpool doesn't handle?
- Did I follow an anti-pattern I'm not aware of?
1 Answer
This is quite nice, there are not many thread pool implementations that actually take care of dealing with the return value of the enqueued tasks.
Split off the thread-safe queue into its own class
A significant part of your code is a queue that multiple threads can add tasks to and remove tasks from. Consider splitting the queue part off into its own class. This will simplify `class Threadpool`.
Return a `std::future` instead of a `TaskHandler`
The `TaskHandler` class is not really necessary. For the caller, the only thing that is necessary is to wait for a result and to get it. This functionality is already provided by `std::future`; you are just adding a wrapper around it. You also don't need to store a shared pointer to `std::packaged_task`; `std::future` is enough to capture the result of the enqueued task. (Holding on to it is even harmful: if the pool is destroyed before the task runs, the handler keeps the `std::packaged_task` alive, so `get()` blocks forever instead of throwing a `broken_promise` error.) This also brings me to:
Simplify storing the tasks in the queue... in C++23
It's surprisingly hard to make your code any cleaner. Ideally you don't want to use `std::shared_ptr` to store the `std::packaged_task`, and `std::packaged_task` itself is also not so useful; I think it would be better to use a `std::promise` instead. Your code uses `std::bind()` multiple times and uses a lambda with a capture, and each of those things (including `std::shared_ptr` and `std::packaged_task`) can allocate memory dynamically. Ideally you avoid as much of that as possible. In C++23 you will be able to write:
std::list<std::move_only_function<void()>> mTasks;
⋮
template <TaskPriority Priority = DEFAULT, class FN, class... Args>
auto queueAndHandleTask(FN&& fn, Args&&... args)
{
    // Prepare a promise and future pair
    using return_type = std::invoke_result_t<FN, Args...>;
    std::promise<return_type> promise;
    auto future = promise.get_future();
    {
        // Get the insertion point
        std::lock_guard lock(mTasksMutex);
        auto insertPosition = getInsertPosition(Priority, mOrdering);
        // Insert a lambda that calls fn and fulfils the promise. As with
        // std::packaged_task, an exception is stored in the promise instead of
        // escaping into the worker thread (which would call std::terminate).
        // void results need their own branch: set_value(expr) does not accept
        // a void expression.
        mTasks.insert(insertPosition,
            [promise = std::move(promise),
             fn = std::forward<FN>(fn),
             ...args = std::forward<Args>(args)] mutable {
                try {
                    if constexpr (std::is_void_v<return_type>) {
                        std::invoke(std::forward<FN>(fn), std::forward<Args>(args)...);
                        promise.set_value();
                    } else {
                        promise.set_value(std::invoke(std::forward<FN>(fn), std::forward<Args>(args)...));
                    }
                } catch (...) {
                    promise.set_exception(std::current_exception());
                }
            });
    }
    // Wake a worker, as queueTask() does
    mTasksCond.notify_one();
    // Return the future to the caller
    return future;
}
This relies on:
- move capture (C++14)
- parameter pack capture (C++20)
- move-only functions (C++23)
You might be able to emulate all this in C++11, but it would result in a lot more code than you have. The advantage, however, is far fewer memory allocations under the hood, although there are still three: one for the list entry, one for the lambda capture, and one for `std::promise`.
Explore related questions
See similar questions with these tags.
`TaskHandler` cannot stop a task. An `interrupt` method could, for example, set a thread-local (TLS) atomic flag to true, so that a worker function could check whether it has been interrupted and return.