15
\$\begingroup\$

I have developed a thread pool using standard C++11 features and am looking for feedback.

Right now I think the implementation is pretty solid, but since I only recently learned about multithreading, I'm not sure I have a solid grasp of it.

Code

Threadpool.h

#pragma once
#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <list>
#include <memory>
#include <mutex>
#include <thread>
namespace threadpool {
    template <class Result> class TaskHandler;

    /**
     * @brief A pool of worker threads that execute queued callables.
     *
     * Every piece of state shared between the owner and the workers (the task
     * list, the worker list, the pending-retirement counter, the worker count,
     * the default ordering and the destruction flag) is only accessed while
     * holding mTasksMutex.
     */
    class Threadpool {
    public:
        // Defines how DEFAULT-priority tasks are ordered
        enum TaskOrdering {
            FIFO, // First In First Out
            LIFO  // Last In First Out
        };
        // Priority with which a task will be queued
        enum TaskPriority {
            MAX,     // The task will be the next one to execute
            DEFAULT, // Placed according to the pool's TaskOrdering (FIFO or LIFO)
            MIN      // The task will be put past all stored tasks
        };

        /**
         * @brief Create a threadpool holding a certain number of threads
         *
         * @param nWorkers number of worker threads to start; may be 0, in which
         *        case queued tasks wait until resize() adds workers
         * @param ordering default ordering for DEFAULT-priority tasks
         *        FIFO = First In First Out
         *        LIFO = Last In First Out
         * @throws std::system_error if a worker thread cannot be started; the
         *         workers already started are joined before it propagates
         */
        explicit Threadpool(unsigned int nWorkers, TaskOrdering ordering = FIFO) :
            mOrdering(ordering),
            mThreadsToKill(0),
            mNWorkers(0),
            mPoolDestroyed(false)
        {
            try {
                // Holding the lock guarantees no worker inspects mWorkers before
                // every std::thread object in it is fully constructed.
                std::lock_guard<std::mutex> lock(mTasksMutex);
                for (unsigned int i = 0; i < nWorkers; ++i) {
                    mWorkers.emplace_back(&Threadpool::workerFunction, this);
                    ++mNWorkers;
                }
            }
            catch (...) {
                // The destructor does not run for a partially constructed object,
                // and destroying a joinable std::thread calls std::terminate.
                shutdown();
                throw;
            }
        }

        Threadpool(const Threadpool&) = delete;
        Threadpool& operator=(const Threadpool&) = delete;

        /**
         * @brief Join all worker threads, leaving all remaining queued tasks undone
         * @details Tasks that are already running finish first. Tasks still in
         * the queue are discarded; for tasks queued with queueAndHandleTask the
         * handler then reports std::future_errc::broken_promise.
         */
        ~Threadpool() {
            shutdown();
        }

        /**
         * @brief Queue a task and return a handle to it
         *
         * @param fn Callable
         * @param args Arguments for fn (bound with std::bind)
         * @tparam Priority where the task is inserted in the queue
         * @return Handle through which the result (or the exception thrown by
         *         fn) can be retrieved
         */
        template <TaskPriority Priority = DEFAULT, class FN, class... Args>
        auto queueAndHandleTask(FN&& fn, Args&&... args)
            -> TaskHandler<typename std::result_of<FN(Args...)>::type>
        {
            using return_type = typename std::result_of<FN(Args...)>::type;
            auto task = std::make_shared<std::packaged_task<return_type()>>(
                std::bind(std::forward<FN>(fn), std::forward<Args>(args)...)
            );
            // Take the future before the task becomes visible to workers:
            // get_future() concurrent with operator() on another thread is a data race.
            TaskHandler<return_type> handler(task);
            queueTask<Priority>([task]() { (*task)(); });
            return handler;
        }

        /**
         * @brief Queue a fire-and-forget task
         * @details An exception escaping the task leaves the worker thread's
         * function, which calls std::terminate. Use queueAndHandleTask to
         * capture exceptions instead.
         *
         * @param fn Callable
         * @param args Arguments for fn (bound with std::bind)
         * @tparam Priority where the task is inserted in the queue
         */
        template <TaskPriority Priority = DEFAULT, class FN, class... Args>
        void queueTask(FN&& fn, Args&&... args) {
            // Build the type-erased task outside the critical section.
            std::function<void()> task =
                std::bind(std::forward<FN>(fn), std::forward<Args>(args)...);
            {
                std::lock_guard<std::mutex> lock(mTasksMutex);
                mTasks.emplace(getInsertPosition(Priority, mOrdering), std::move(task));
            }
            mTasksCond.notify_one();
        }

        /**
         * @brief Changes the default task ordering for the threadpool
         * @param newOrdering ordering used for subsequently queued DEFAULT tasks
         */
        void setDefaultOrdering(TaskOrdering newOrdering) {
            // queueTask reads mOrdering under the lock, so write it under the lock too.
            std::lock_guard<std::mutex> lock(mTasksMutex);
            mOrdering = newOrdering;
        }

        /**
         * @brief Change the number of threads managed by the threadpool
         * @details Shrinking is asynchronous: surplus workers exit once they are
         * idle or have finished their current task.
         * @param newWorkers new thread number
         */
        void resize(unsigned int newWorkers) {
            unsigned int toRetire = 0;
            {
                std::lock_guard<std::mutex> lock(mTasksMutex);
                if (newWorkers >= mNWorkers) {
                    unsigned int missing = newWorkers - mNWorkers;
                    // Workers scheduled to retire are still alive: cancelling
                    // their retirement is cheaper than spawning replacements.
                    const unsigned int reprieved = std::min(missing, mThreadsToKill);
                    mThreadsToKill -= reprieved;
                    mNWorkers += reprieved;
                    missing -= reprieved;
                    for (; missing > 0; --missing) {
                        mWorkers.emplace_back(&Threadpool::workerFunction, this);
                        ++mNWorkers;
                    }
                }
                else {
                    toRetire = mNWorkers - newWorkers;
                    // Accumulate: earlier retirements may still be pending.
                    mThreadsToKill += toRetire;
                    mNWorkers = newWorkers;
                }
            }
            // Wake one idle worker per retirement; busy workers notice
            // mThreadsToKill when they come back for their next task.
            for (unsigned int i = 0; i < toRetire; ++i) {
                mTasksCond.notify_one();
            }
        }

        /**
         * @brief Blocks the calling thread until the task queue is empty
         * @details Returns once every queued task has been taken by a worker;
         * the most recently taken tasks may still be running. Blocks
         * indefinitely if tasks are queued and the pool has no workers.
         */
        void waitForTasks() {
            // mTasks is guarded by mTasksMutex, so the predicate must run under it.
            std::unique_lock<std::mutex> lock(mTasksMutex);
            mWaitCond.wait(lock, [this] { return mTasks.empty(); });
        }

    private:
        enum WorkerAction {
            WAIT,
            WORK,
            EXIT,
            SUICIDE
        };

        // Decides what a worker does next. Caller must hold mTasksMutex.
        WorkerAction getNextAction() const {
            if (mPoolDestroyed) return EXIT;
            if (mThreadsToKill > 0) return SUICIDE;
            if (!mTasks.empty()) return WORK;
            return WAIT;
        }

        // Where a new task goes in mTasks. Caller must hold mTasksMutex.
        std::list<std::function<void()>>::const_iterator
        getInsertPosition(TaskPriority priority, TaskOrdering ordering) const {
            if (priority == MAX) return mTasks.cbegin();
            if (priority == MIN) return mTasks.cend();
            return (ordering == LIFO) ? mTasks.cbegin() : mTasks.cend();
        }

        // Removes the calling worker from mWorkers and detaches it so it can
        // exit on its own. Caller must hold mTasksMutex.
        void retireCurrentWorker() {
            assert(mThreadsToKill > 0);
            --mThreadsToKill;
            const std::thread::id self = std::this_thread::get_id();
            auto it = std::find_if(mWorkers.begin(), mWorkers.end(),
                [self](const std::thread& worker) { return worker.get_id() == self; });
            assert(it != mWorkers.end());
            it->detach();
            mWorkers.erase(it);
        }

        // Tells every worker to exit after its current task and joins them.
        void shutdown() {
            {
                std::lock_guard<std::mutex> lock(mTasksMutex);
                mPoolDestroyed = true;
                // cancel possible thread suicides: every worker now exits anyway
                mThreadsToKill = 0;
            }
            mTasksCond.notify_all();
            // No lock needed: with mPoolDestroyed set, getNextAction never
            // returns SUICIDE again, so workers no longer modify mWorkers.
            for (std::thread& worker : mWorkers) {
                worker.join();
            }
        }

        // Main loop of every worker thread.
        void workerFunction() {
            while (true) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(mTasksMutex);
                    WorkerAction action = WAIT;
                    // The predicate is checked before blocking, so an action
                    // that is already pending is handled immediately.
                    mTasksCond.wait(lock, [this, &action] {
                        return (action = getNextAction()) != WAIT;
                    });
                    if (action == EXIT) {
                        return;
                    }
                    if (action == SUICIDE) {
                        retireCurrentWorker();
                        return;
                    }
                    assert(action == WORK);
                    task = std::move(mTasks.front());
                    mTasks.pop_front();
                }
                // The task itself is not performed under the lock
                task();
                // Let every waitForTasks() caller re-check whether the queue is empty
                mWaitCond.notify_all();
            }
        }

        // list used so that erasing one worker never invalidates the others
        std::list<std::thread> mWorkers;
        std::list<std::function<void()>> mTasks;
        // wakes workers when a task is queued, the pool resizes, or it is destroyed
        std::condition_variable mTasksCond;
        // guards all state listed in the class comment
        std::mutex mTasksMutex;
        // wakes waitForTasks() callers; always used together with mTasksMutex
        std::condition_variable mWaitCond;
        TaskOrdering mOrdering;
        // number of workers that still have to retire as a result of resize()
        unsigned int mThreadsToKill;
        // worker count the pool is converging to (live workers minus pending retirements)
        unsigned int mNWorkers;
        bool mPoolDestroyed;
    };

    /**
     * @brief Handles a task state and allows to get the task result
     *
     * @tparam Result type of the task result
     */
    template <class Result>
    class TaskHandler {
        // Users are not allowed to construct new task handlers; only the pool does.
        friend class Threadpool;
    public:
        TaskHandler(TaskHandler&&) noexcept = default;
        TaskHandler& operator=(TaskHandler&&) noexcept = default;
        TaskHandler(const TaskHandler&) = delete;
        TaskHandler& operator=(const TaskHandler&) = delete;

        /**
         * @brief Returns whether the handler still refers to a result
         * @return false after get() has been called or after being moved from
         */
        bool valid() const noexcept { return mResult.valid(); }

        /**
         * @brief Returns whether the task has completed
         * @details Also true when the task was dropped by the pool's destructor
         * (get() then throws std::future_error). Returns false when the
         * handler is no longer valid().
         * @return bool
         */
        bool resultReady() const {
            return mResult.valid() &&
                   mResult.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
        }

        /**
         * @brief Get the task result
         * @details Blocks the calling thread until the result is ready. May be
         * called only once (check valid()). Rethrows any exception thrown by the
         * task; throws std::future_error (broken_promise) if the pool was
         * destroyed before running the task.
         * @return The value the task returned
         */
        Result get() { return mResult.get(); }

        /**
         * @brief Blocks the calling thread until the result is ready
         */
        void wait() const { mResult.wait(); }

    private:
        // Only the future is kept: holding the packaged_task here as well would
        // keep a discarded task alive and make get() block forever.
        explicit TaskHandler(const std::shared_ptr<std::packaged_task<Result()>>& task) :
            mResult(task->get_future())
        {}

        std::future<Result> mResult;
    };
}

Examples - main.cpp

// Threadpool.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include "Threadpool.h"
#include <iostream>
#include <memory>
using namespace threadpool;
// Small callable used by the examples: invoking it returns the stored value,
// and expensiveMethod() is a deliberately slow mutator.
struct SampleFunctor {
    // explicit: an int should not silently convert into a functor
    explicit SampleFunctor(int a) : data(a) {}
    // Returns the stored value; const because it does not modify the functor.
    int operator()() const { return data; }
    // Sleeps for five seconds, then increments data.
    void expensiveMethod() {
        std::this_thread::sleep_for(std::chrono::seconds(5));
        ++data;
    }
    int data;
};
// Factory used by the free-function example: heap-allocates a SampleFunctor
// holding `data` and transfers ownership to the caller.
auto makeFunctor(int data) -> std::unique_ptr<SampleFunctor> {
    auto functor = std::make_unique<SampleFunctor>(data);
    return functor;
}
// Example singleton used to show that static member functions can be queued.
class SomeManager {
public:
    // Returns the single instance, sleeping 3 seconds on every call to
    // simulate slow work. Initialisation of the function-local static is
    // thread-safe since C++11.
    static SomeManager& getInstance()
    {
        static SomeManager instance;
        std::this_thread::sleep_for(std::chrono::seconds(3));
        return instance;
    }
    // A singleton must not be copied; otherwise `SomeManager m = getInstance();`
    // would silently create a second instance.
    SomeManager(const SomeManager&) = delete;
    SomeManager& operator=(const SomeManager&) = delete;
private:
    SomeManager() = default;
};
int main()
{
 Threadpool threadpool(4);
 {
 // Queue free functions
 auto handler = threadpool.queueAndHandleTask(makeFunctor, 4);
 auto result = handler.get();
 std::cout << "Result is " << result->data << "\n";
 }
 // Queue a lambda
 threadpool.queueTask([]() {
 int counter = 0;
 while (counter < 1000) ++counter;
 });
 
 {
 // Queue functors
 auto handler = threadpool.queueAndHandleTask(SampleFunctor{ 20 });
 auto result = handler.get();
 std::cout << "Functor example: " << result << "\n";
 }
 {
 // Queue member methods
 // You can handle void tasks too!
 SampleFunctor f{ 20 };
 auto handler = threadpool.queueAndHandleTask(&SampleFunctor::expensiveMethod, &f);
 // note that this blocks the calling thread
 handler.wait();
 // you can use handler.get() too (if you dont assign it to anything)
 std::cout << "Expensive method done! Data should be 21 -> " << f.data << "\n";
 }
 {
 // Queue static class methods
 auto handler = threadpool.queueAndHandleTask(&SomeManager::getInstance);
 SomeManager& managerSingleton = handler.get();
 std::cout << "Our singleton is located at " << &managerSingleton << "\n";
 }
 return 0;
}

Extended usage

Additional features:

  • Adding/removing threads at runtime
  • Getting tasks results and waiting for them to be ready
  • Task order and priority

You can find the complete documentation in the GitHub repo.

Questions

My main concerns are:

  • Does the interface seem easy to use?
  • Is there some edge-case the threadpool doesn't handle?
  • Did I follow some anti-pattern I don't know of?
Mast
13.8k12 gold badges56 silver badges127 bronze badges
asked Oct 15, 2017 at 16:31
\$\endgroup\$
1
  • 2
    \$\begingroup\$ The TaskHandler cannot stop a task. An interrupt method could e.g. set a TLS atomic value to true, so a worker function could check whether it's interrupted and return. \$\endgroup\$ Commented Jun 5, 2020 at 15:39

1 Answer 1

1
\$\begingroup\$

This is quite nice, there are not many thread pool implementations that actually take care of dealing with the return value of the enqueued tasks.

Split off the thread-safe queue into its own class

A significant part of your code is having a queue that multiple threads can add/remove tasks to/from. Consider splitting the queue part off into its own class. This will simplify class Threadpool.

Return a std::future instead of a TaskHandler

The TaskHandler class is not really necessary. For the caller, the only thing that is necessary is to wait for a result and to get it. This functionality is already provided by std::future; you are just adding a wrapper around it. You also don't need to store a shared pointer to std::packaged_task; std::future is enough to capture the result of the enqueued task. This also brings me to:

Simplify storing the tasks in the queue... in C++23

It's surprisingly hard to make your code any cleaner. Ideally you don't want to use std::shared_ptr to store the std::packaged_task, and std::packaged_task itself is also not so useful; I think it would be better to use a std::promise instead. Your code uses std::bind() multiple times and uses a lambda with a capture, and each of those things (including std::shared_ptr and std::packaged_task) can allocate memory dynamically. Ideally you avoid as much of that as possible. In C++23 you will be able to write:

std::list<std::move_only_function<void()>> mTasks;
⋮
template <TaskPriority Priority = DEFAULT, class FN, class... Args>
auto queueAndHandleTask(FN&& fn, Args&&... args)
{
 // Prepare a promise and future pair
 using return_type = std::invoke_result_t<FN, Args...>;
 std::promise<return_type> promise;
 auto future = promise.get_future();
 // Get the insertion point
 std::lock_guard lock(mTasksMutex);
 auto insertPosition = getInsertPosition(Priority, mOrdering);
 // Insert a lambda that will call fn and sets the promise
 mTasks.insert(insertPosition,
 [promise = std::move(promise),
 fn = std::forward<FN>(fn),
 ...args = std::forward<Args>(args)] mutable {
 promise.set_value(std::invoke(std::forward<FN>(fn), std::forward<Args>(args)...));
 });
 // Return the future to the caller
 return future;
}

This relies on:

  • move capture (C++14)
  • parameter pack capture (C++20)
  • move-only functions (C++23)

You might be able to emulate all this in C++11, but it would result in a lot more code than you have. The advantage, however, is far fewer memory allocations under the hood, although there are still three: one for the list entry, one for the lambda capture, and one for the shared state behind std::promise.

answered Nov 28, 2022 at 23:17
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.