9
\$\begingroup\$

I wrote this parallel merge sort with no external dependencies, trying to use modern C++ as much as possible.
Could you please review it and tell me what you think?

#include <vector>
#include <list>
#include <thread>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <algorithm>
#include <utility>
#include <exception>
#include <cassert>
#include <iterator>
// Scope guard: invokes the referenced callable when the guard is destroyed,
// unless it has already been run through invoke_and_disable().
// The callable is held by reference, so it must outlive the guard.
// Destructors are implicitly noexcept: a callable that throws while being
// run from the destructor terminates the program.
template<typename T>
struct invoke_on_destruct
{
private:
	T &m_t;
	bool m_enabled;
public:
	explicit invoke_on_destruct( T &t ) :
		m_t( t ), m_enabled( true )
	{
	}
	// A copy would run the callable a second time, so the guard is non-copyable.
	invoke_on_destruct( invoke_on_destruct const & ) = delete;
	invoke_on_destruct &operator =( invoke_on_destruct const & ) = delete;
	~invoke_on_destruct()
	{
		if( m_enabled )
			m_t();
	}
	// Runs the callable now and prevents the destructor from running it again.
	// The guard is disarmed *before* the call, so a callable that throws here
	// is not invoked a second time (and possibly re-thrown) from the destructor.
	void invoke_and_disable()
	{
		m_enabled = false;
		m_t();
	}
};
// Thrown by parallel_merge_sort / ref_parallel_merge_sort when sorting fails.
// what() returns a descriptive message instead of the implementation-defined
// text inherited from std::exception.
struct sort_exception : public std::exception
{
	char const *what() const noexcept override
	{
		return "parallel merge sort failed";
	}
};
// Merge sort over a random-access range that can hand the right half of large
// sub-ranges to pooled worker threads. Worker threads and scratch buffers are
// kept between calls to sort(); empty_buffers() releases the buffers.
template<typename InputIt, typename P = std::less<typename std::iterator_traits<InputIt>::value_type>>
class parallel_merge_sort
{
public:
	parallel_merge_sort( P const &p = P() );
	~parallel_merge_sort();
	// Sorts [itBegin, itBegin + n). Sub-ranges with fewer than minThreaded
	// elements are sorted on the current thread. Failures surface as sort_exception.
	void sort( InputIt itBegin, size_t n, std::size_t minThreaded );
	// Capacity (in elements) of the caller's and the standby workers' scratch buffers.
	std::size_t get_buffer_size();
	// Clears and shrinks the caller's and the standby workers' scratch buffers.
	void empty_buffers();
private:
	typedef typename std::iterator_traits<InputIt>::value_type value_type;
	typedef typename std::vector<value_type> buffer_type;
	typedef typename buffer_type::iterator buffer_iterator;
	// One pooled worker thread plus the mailbox used to hand it work.
	struct pool_thread
	{
		enum CMD : int { CMD_STOP = -1, CMD_NONE = 0, CMD_SORT = 1 };
		enum RSP : int { RSP_ERR = -1, RSP_NONE = 0, RSP_SUCCESS = 1 };
		// Members are constructed in declaration order. Everything the worker
		// uses is declared before m_thread, so it is fully constructed before
		// the worker starts running sort_thread().
		std::mutex m_mtx;
		std::condition_variable m_sigInitiate;
		CMD m_cmd = CMD_NONE;
		buffer_iterator m_itBegin;
		std::size_t m_n = 0;
		std::condition_variable m_sigResponse;
		RSP m_rsp = RSP_NONE;
		std::vector<value_type> m_sortBuf;
		std::thread m_thread;
		pool_thread( parallel_merge_sort *pPMS );
		~pool_thread();
		void sort_thread( parallel_merge_sort *pPMS );
		static std::size_t calc_buffer_size( size_t n );
	};
	P m_p;
	std::size_t m_minThreaded = 0;
	unsigned m_maxRightThreads = 0;
	buffer_type m_callerSortBuf;
	std::mutex m_mtxPool;
	std::list<pool_thread> m_standbyThreads;
	std::list<pool_thread> m_activeThreads;
	template<typename InputIt2>
	void threaded_sort( InputIt2 itBegin, std::size_t n, buffer_iterator itSortBuf );
	template<typename InputIt2>
	void unthreaded_sort( InputIt2 itBegin, std::size_t n, buffer_iterator itSortBuf );
	template<typename OutputIt>
	void merge_back( OutputIt itUp, buffer_iterator itLeft, buffer_iterator itLeftEnd, buffer_iterator itRight, buffer_iterator itRightEnd );
};
// Stores the comparison predicate and limits the number of concurrently active
// workers to one less than the reported hardware thread count (zero when that
// count is unknown, i.e. hardware_concurrency() returns 0).
template<typename InputIt, typename P>
parallel_merge_sort<InputIt, P>::parallel_merge_sort( P const &p ) :
	m_p( p )
{
	unsigned const hardwareThreads = std::thread::hardware_concurrency();
	m_maxRightThreads = 0;
	if( hardwareThreads > 0 )
		m_maxRightThreads = hardwareThreads - 1;
}
// Sorts the n elements starting at itBegin. Sub-ranges with fewer than
// minThreaded elements (minThreaded is raised to at least 4) are sorted without
// handing work to another thread. Ranges of fewer than 2 elements are left as-is.
// Any failure is reported as sort_exception. A non-sort_exception cause is kept
// as a nested exception (see std::rethrow_if_nested) instead of being discarded.
template<typename InputIt, typename P>
void parallel_merge_sort<InputIt, P>::sort( InputIt itBegin, size_t n, std::size_t minThreaded )
{
	std::size_t const minSize = 2;
	if( n < minSize )
		return;
	m_minThreaded = std::max( minThreaded, 2 * minSize );
	try
	{
		std::size_t const s = pool_thread::calc_buffer_size( n );
		if( m_callerSortBuf.size() < s )
			m_callerSortBuf.resize( s );
		threaded_sort( itBegin, n, m_callerSortBuf.begin() );
	}
	catch( sort_exception const & )
	{
		// Already the documented type (e.g. a worker failure): don't wrap again.
		throw;
	}
	catch( ... )
	{
		std::throw_with_nested( sort_exception() );
	}
}
// Debug check: no worker may still be in the active list when the sorter dies.
// The standby workers are stopped and joined by ~pool_thread() when the member
// lists are destroyed.
template<typename InputIt, typename P>
parallel_merge_sort<InputIt, P>::~parallel_merge_sort()
{
	assert(m_activeThreads.empty());
}
// Scratch-buffer size (in elements) requested for sorting n elements: n for the
// current level, plus the larger (right) half of the remaining range, repeated
// while that remaining range has more than 2 elements.
template<typename InputIt, typename P>
inline
std::size_t parallel_merge_sort<InputIt, P>::pool_thread::calc_buffer_size( std::size_t n )
{
	std::size_t total = n;
	std::size_t rest = n;
	while( rest > 2 )
	{
		rest -= rest / 2;	// keep the larger half
		total += rest;
	}
	return total;
}
// Posts CMD_STOP to the worker (under m_mtx), then waits for its thread to exit.
template<typename InputIt, typename P>
parallel_merge_sort<InputIt, P>::pool_thread::~pool_thread()
{
	{
		std::lock_guard<std::mutex> guard( m_mtx );
		m_cmd = CMD_STOP;
		m_sigInitiate.notify_one();
	}
	m_thread.join();
}
// Sorts the n elements at itBegin, using the storage starting at itSortBuf as
// scratch space.
// Runs unthreaded_sort() on the calling thread when n < m_minThreaded, or when
// no standby worker exists and m_activeThreads already holds m_maxRightThreads
// workers. Otherwise it picks (or creates) a worker and moves all n elements
// into itSortBuf. It hands the right half [itSortBuf + n/2, itSortBuf + n) to
// the worker and sorts the left half recursively, with scratch at itSortBuf + n.
// It then waits for the worker and merges both halves back into itBegin.
// Throws sort_exception if the worker reports RSP_ERR.
template<typename InputIt, typename P>
template<typename InputIt2>
void parallel_merge_sort<InputIt, P>::threaded_sort( InputIt2 itBegin, std::size_t n, buffer_iterator itSortBuf )
{
 using namespace std;
 // m_mtxPool is held here while m_standbyThreads / m_activeThreads are inspected or modified.
 unique_lock<mutex> poolLock( m_mtxPool );
 if( n < m_minThreaded || (m_standbyThreads.empty() && m_activeThreads.size() >= m_maxRightThreads) )
 {
 poolLock.unlock();
 unthreaded_sort( itBegin, n, itSortBuf );
 return;
 }
 typedef typename list<pool_thread>::iterator pt_it;
 pt_it itPT;
 pool_thread *pPT;
 size_t left = n / 2,
 right = n - left;
 if( !m_standbyThreads.empty() )
 {
 // Prefer the standby worker with the smallest m_sortBuf that already holds
 // calc_buffer_size( right ) elements; if none is big enough, take the last one.
 pt_it itPTScan;
 size_t optimalSize = pool_thread::calc_buffer_size( right ),
 bestFit = (size_t)(ptrdiff_t)-1,
 size;
 for( itPT = m_standbyThreads.end(), itPTScan = m_standbyThreads.begin();
 itPTScan != m_standbyThreads.end(); ++itPTScan )
 if( (size = itPTScan->m_sortBuf.size()) >= optimalSize && size < bestFit )
 itPT = itPTScan,
 bestFit = size;
 if( itPT == m_standbyThreads.end() )
 itPT = --m_standbyThreads.end();
 // splice keeps itPT valid while moving the node into the active list.
 m_activeThreads.splice( m_activeThreads.end(), m_standbyThreads, itPT );
 poolLock.unlock();
 pPT = &*itPT;
 }
 else
 // No standby worker: construct (and thereby start) a new one in the active list.
 m_activeThreads.emplace_back( this ),
 itPT = --m_activeThreads.end(),
 pPT = &*itPT,
 poolLock.unlock();
 // On every exit path the worker is moved back to m_standbyThreads. poolLock is
 // re-locked for the splice and released when poolLock itself is destroyed.
 auto pushThreadBack = [&poolLock, &itPT, this]()
 {
 poolLock.lock();
 m_standbyThreads.splice( m_standbyThreads.end(), m_activeThreads, itPT );
 };
 invoke_on_destruct<decltype(pushThreadBack)> autoPushBackThread( pushThreadBack );
 // Move all n elements into the scratch buffer; they are merged back into itBegin at the end.
 buffer_iterator itMoveTo = itSortBuf;
 for( InputIt2 itMoveFrom = itBegin, itEnd = itMoveFrom + n; itMoveFrom != itEnd; *itMoveTo = move( *itMoveFrom ), ++itMoveTo, ++itMoveFrom );
 buffer_iterator itLeft = itSortBuf,
 itRight = itLeft + left;
 // Post the right half to the worker under its mutex.
 unique_lock<mutex> threadLock( pPT->m_mtx );
 pPT->m_cmd = pool_thread::CMD_SORT;
 pPT->m_rsp = pool_thread::RSP_NONE;
 pPT->m_itBegin = itRight;
 pPT->m_n = right;
 pPT->m_sigInitiate.notify_one();
 threadLock.unlock();
 // Blocks until the worker posts a response and leaves threadLock locked.
 // It is armed as a guard so the worker is also waited for when sorting the
 // left half throws; the worker must stop using the buffer before it is released.
 auto waitForThread = [&threadLock, pPT]()
 {
 threadLock.lock();
 while( pPT->m_rsp == pool_thread::RSP_NONE )
 pPT->m_sigResponse.wait( threadLock );
 assert(pPT->m_rsp == pool_thread::RSP_SUCCESS || pPT->m_rsp == pool_thread::RSP_ERR);
 };
 invoke_on_destruct<decltype(waitForThread)> autoWaitForThread( waitForThread );
 // The left half is sorted on this thread (possibly delegating further) while the worker sorts the right half.
 threaded_sort( itLeft, left, itSortBuf + n );
 autoWaitForThread.invoke_and_disable();
 // threadLock is held here, so m_rsp is read under the worker's mutex.
 if( pPT->m_rsp == pool_thread::RSP_ERR )
 throw sort_exception();
 threadLock.unlock();
 merge_back( itBegin, itLeft, itLeft + left, itRight, itRight + right );
}
// Sorts the n (>= 2) elements at itBegin on the calling thread.
// Two elements are ordered in place. Larger ranges are moved into itSortBuf,
// both halves are sorted recursively (using the buffer beyond the first n slots
// as their scratch area), and the halves are merged back into itBegin.
template<typename InputIt, typename P>
template<typename InputIt2>
void parallel_merge_sort<InputIt, P>::unthreaded_sort( InputIt2 itBegin, std::size_t n, buffer_iterator itSortBuf )
{
	assert(n >= 2);
	if( n == 2 )
	{
		if( m_p( itBegin[1], itBegin[0] ) )
		{
			value_type held( std::move( itBegin[0] ) );
			itBegin[0] = std::move( itBegin[1] );
			itBegin[1] = std::move( held );
		}
		return;
	}
	std::move( itBegin, itBegin + n, itSortBuf );
	std::size_t const leftCount = n / 2;
	std::size_t const rightCount = n - leftCount;
	buffer_iterator const leftHalf = itSortBuf;
	buffer_iterator const rightHalf = itSortBuf + leftCount;
	buffer_iterator const scratch = itSortBuf + n;
	if( leftCount >= 2 )
		unthreaded_sort( leftHalf, leftCount, scratch );
	if( rightCount >= 2 )
		unthreaded_sort( rightHalf, rightCount, scratch );
	merge_back( itBegin, leftHalf, leftHalf + leftCount, rightHalf, rightHalf + rightCount );
}
// Merges the sorted, non-empty runs [itLeft, itLeftEnd) and [itRight, itRightEnd)
// by moving their elements, in order, to the range starting at itUp.
// When two elements compare equal, the one from the left run is taken first,
// so the merge is stable.
template<typename InputIt, typename P>
template<typename OutputIt>
inline
void parallel_merge_sort<InputIt, P>::merge_back( OutputIt itUp, buffer_iterator itLeft, buffer_iterator itLeftEnd, buffer_iterator itRight, buffer_iterator itRightEnd )
{
	assert(itLeft < itLeftEnd && itRight < itRightEnd);
	while( itLeft != itLeftEnd && itRight != itRightEnd )
	{
		// Take from the right run only if it is strictly smaller (keeps equal keys in order).
		if( m_p( *itRight, *itLeft ) )
		{
			*itUp = std::move( *itRight );
			++itRight;
		}
		else
		{
			*itUp = std::move( *itLeft );
			++itLeft;
		}
		++itUp;
	}
	// At most one run still has elements; move the remainder of each run.
	// (The original copied *itRight here for the left tail, reading past the end.)
	itUp = std::move( itLeft, itLeftEnd, itUp );
	std::move( itRight, itRightEnd, itUp );
}
// Returns the combined capacity, in elements, of the caller's scratch buffer and
// of every standby worker's scratch buffer (workers in m_activeThreads are not counted).
template<typename InputIt, typename P>
std::size_t parallel_merge_sort<InputIt, P>::get_buffer_size()
{
	std::size_t total = m_callerSortBuf.capacity();
	for( auto const &worker : m_standbyThreads )
		total += worker.m_sortBuf.capacity();
	return total;
}
// Clears the scratch buffer of every standby worker and of the caller, and asks
// each to release its memory (shrink_to_fit is a non-binding request).
template<typename InputIt, typename P>
void parallel_merge_sort<InputIt, P>::empty_buffers()
{
	auto release = []( buffer_type &buf )
	{
		buf.clear();
		buf.shrink_to_fit();
	};
	for( auto &worker : m_standbyThreads )
		release( worker.m_sortBuf );
	release( m_callerSortBuf );
}
// Creates an idle worker bound to the sorter pPMS.
// The thread is started in the constructor body, not in the member-initializer
// list. That way every member sort_thread() touches (m_mtx, m_sigInitiate, m_cmd, ...)
// is fully constructed before the worker runs, regardless of the order in which
// the members are declared.
template<typename InputIt, typename P>
parallel_merge_sort<InputIt, P>::pool_thread::pool_thread( parallel_merge_sort *pPMS ) :
	m_mtx(),
	m_sigInitiate(),
	m_cmd( pool_thread::CMD_NONE ),
	m_rsp( pool_thread::RSP_NONE )
{
	m_thread = std::thread( &pool_thread::sort_thread, this, pPMS );
}
// Worker loop. It waits under m_mtx for a command and returns on CMD_STOP.
// On CMD_SORT it sorts the m_n elements at m_itBegin, using (and growing, if
// needed) this worker's own m_sortBuf. It then reports RSP_SUCCESS or RSP_ERR
// through m_rsp and signals m_sigResponse.
template<typename InputIt, typename P>
void parallel_merge_sort<InputIt, P>::pool_thread::sort_thread( parallel_merge_sort *pPMS )
{
	while( true )
	{
		std::unique_lock<std::mutex> lock( m_mtx );
		m_sigInitiate.wait( lock, [this]() { return m_cmd != CMD_NONE; } );
		if( m_cmd == CMD_STOP )
			return;
		assert(m_cmd == pool_thread::CMD_SORT);
		m_cmd = CMD_NONE;
		lock.unlock();
		RSP outcome = RSP_SUCCESS;
		try
		{
			std::size_t const needed = calc_buffer_size( m_n );
			if( m_sortBuf.size() < needed )
				m_sortBuf.resize( needed );
			pPMS->threaded_sort( m_itBegin, m_n, m_sortBuf.begin() );
		}
		catch( ... )
		{
			outcome = RSP_ERR;
		}
		lock.lock();
		m_rsp = outcome;
		m_sigResponse.notify_one();
	}
}
// Sorts a range indirectly. Iterators to the elements (wrapped in ref) are sorted
// with parallel_merge_sort, and the elements are then moved into their sorted positions.
template<typename InputIt, typename P = std::less<typename std::iterator_traits<InputIt>::value_type>>
class ref_parallel_merge_sort
{
private:
	// Iterator to one element of the range being sorted.
	struct ref
	{
		InputIt it;
	};
	// Adapts the element predicate P to compare the elements two refs point at.
	struct ref_predicate
	{
		ref_predicate( P p );
		bool operator ()( ref const &left, ref const &right );
		P m_p;
	};
public:
	ref_parallel_merge_sort( P const &p = P() );
	void sort( InputIt itBegin, size_t n, std::size_t maxUnthreaded );
	std::size_t get_buffer_size();
	void empty_buffers();
private:
	// The refs live contiguously in a std::vector<ref> and are passed to the
	// sorter as a ref *, so the sorter's iterator type is ref *. (ref itself is
	// not an iterator; std::iterator_traits<ref> has no value_type.)
	parallel_merge_sort<ref *, ref_predicate> m_sorter;
};
// Keeps a copy of the element predicate.
template<typename InputIt, typename P>
inline
ref_parallel_merge_sort<InputIt, P>::ref_predicate::ref_predicate( P p )
	: m_p( p )
{
}
// Orders two refs by applying the stored predicate to the elements they refer to.
template<typename InputIt, typename P>
inline
bool ref_parallel_merge_sort<InputIt, P>::ref_predicate::operator ()( ref const &left, ref const &right )
{
	return m_p(
		*left.it,
		*right.it );
}
template<typename InputIt, typename P>
inline
ref_parallel_merge_sort<InputIt, P>::ref_parallel_merge_sort( P const &p ) :
 m_sorter( ref_predicate( p ) )
{
}
// Sorts the n elements starting at itBegin indirectly. Iterators to the elements
// are sorted with m_sorter; then every element is moved exactly once into a
// temporary buffer in sorted order and moved back into the range.
// Any failure is reported as sort_exception. A non-sort_exception cause is kept
// as a nested exception instead of being discarded.
template<typename InputIt, typename P>
void ref_parallel_merge_sort<InputIt, P>::sort( InputIt itBegin, size_t n, std::size_t maxUnthreaded )
{
	if( n < 2 )
		return;	// already sorted; also avoids indexing an empty vector
	try
	{
		typedef typename std::iterator_traits<InputIt>::value_type value_type;
		std::vector<ref> refBuf;
		refBuf.reserve( n );
		InputIt it = itBegin;
		for( std::size_t i = 0; i != n; ++i, ++it )
			refBuf.push_back( ref{ it } );
		m_sorter.sort( refBuf.data(), n, maxUnthreaded );
		// refBuf[i] now names the element that belongs at position i. Collect the
		// elements in that order (the original code ignored refBuf here, so the
		// range was never reordered), then move them back into the range.
		std::vector<value_type> reorderBuf;
		reorderBuf.reserve( n );
		for( ref const &r : refBuf )
			reorderBuf.push_back( std::move( *r.it ) );
		it = itBegin;
		for( value_type &v : reorderBuf )
		{
			*it = std::move( v );
			++it;
		}
	}
	catch( sort_exception const & )
	{
		throw;
	}
	catch( ... )
	{
		std::throw_with_nested( sort_exception() );
	}
}
// Scratch-buffer capacity (in ref elements) reported by the underlying sorter.
template<typename InputIt, typename P>
inline
std::size_t ref_parallel_merge_sort<InputIt, P>::get_buffer_size()
{
	std::size_t const capacity = m_sorter.get_buffer_size();
	return capacity;
}
// Forwards to the underlying sorter so it releases its scratch buffers.
template<typename InputIt, typename P>
inline
void ref_parallel_merge_sort<InputIt, P>::empty_buffers()
{
	this->m_sorter.empty_buffers();
}
#include <iostream>
#include <cstdlib>
#include <functional>
#include <random>
#include <cstdint>
#include <iterator>
#include <type_traits>
#include <chrono>
// Returns a timestamp in microseconds from a monotonic clock. Only differences
// between two calls are meaningful. Uses std::chrono::steady_clock, which is
// portable and, unlike the system wall clock, never jumps when the clock is
// adjusted. This replaces the per-OS versions and the ill-formed bare `#elif`.
double get_usecs()
{
	using std::chrono::steady_clock;
	return std::chrono::duration<double, std::micro>( steady_clock::now().time_since_epoch() ).count();
}
using namespace std;
// Writes n doubles drawn from a uniform_real_distribution<double> (default
// range [0, 1)) to p[0 .. n), using a default_random_engine seeded with `seed`.
// A given seed yields the same values on a given standard library.
void fill_with_random( double *p, size_t n, unsigned seed = 0 )
{
	std::default_random_engine engine( seed );
	std::uniform_real_distribution<double> uniform01;
	std::generate_n( p, n, [&engine, &uniform01]() { return uniform01( engine ); } );
}
// Declaration of decimal_unsigned(): formats t in decimal with '.' inserted
// between groups of three digits (via decify_string). Only unsigned types are
// accepted (enable_if on is_unsigned). The definition follows main().
template<typename T, typename = typename enable_if<is_unsigned<T>::value, T>::type>
string decimal_unsigned( T t );
// Benchmark: sorts 1 GiB of random doubles twice, once allowing worker threads
// and once entirely on the calling thread. For each run it prints the elapsed
// time and the scratch-buffer size in bytes reported by get_buffer_size().
int main()
{
	using it_type = vector<double>::iterator;
	size_t const elementCount = (size_t)1024 * 1024 * 1024 / sizeof(double);
	// hardware_concurrency() returns 0 when the value is not computable; avoid dividing by it.
	unsigned const hc = max( thread::hardware_concurrency(), 1u );
	vector<double> v( elementCount );
	parallel_merge_sort<it_type> sd;
	double t;

	fill_with_random( v.data(), elementCount );
	t = get_usecs();
	sd.sort( v.begin(), elementCount, elementCount / hc );
	t = get_usecs() - t;
	cout << (t / 1'000'000.0) << " seconds parallel\n";
	cout << decimal_unsigned( sd.get_buffer_size() * sizeof(double) ) << '\n';
	sd.empty_buffers();

	fill_with_random( v.data(), elementCount );
	t = get_usecs();
	// A range is sorted unthreaded only when it has fewer than minThreaded
	// elements, so minThreaded must exceed the element count for a fully
	// sequential run (passing elementCount itself still forks once at the top).
	sd.sort( v.begin(), elementCount, elementCount + 1 );
	t = get_usecs() - t;
	cout << (t / 1'000'000.0) << " seconds sequential\n";
	cout << decimal_unsigned( sd.get_buffer_size() * sizeof(double) ) << '\n';
	sd.empty_buffers();
}
#include <sstream>
string decify_string( string const &s );
template<typename T, typename>
string decimal_unsigned( T t )
{
 using namespace std;
 ostringstream oss;
 return move( decify_string( (oss << t, oss.str()) ) );
}
// Inserts '.' between every group of three characters, counting from the right,
// e.g. "1234567" -> "1.234.567". Strings of at most three characters (including
// the empty string) are returned unchanged.
std::string decify_string( std::string const &s )
{
	std::size_t const length = s.length();
	// Size of the leading group: 1..3 characters (0 only for an empty string).
	std::size_t head = length % 3;
	if( head == 0 && length != 0 )
		head = 3;
	std::string result;
	result.reserve( length + length / 3 );
	result.append( s, 0, head );
	for( std::size_t i = head; i < length; i += 3 )
	{
		result += '.';
		result.append( s, i, 3 );
	}
	return result;	// no std::move: it would prevent copy elision
}
200_success
145k22 gold badges190 silver badges478 bronze badges
asked May 2, 2019 at 17:12
\$\endgroup\$

1 Answer 1

1
\$\begingroup\$

Design-wise, I don't like the tight coupling of the sort and the thread pool. Users expect a sort to be a plain function; it probably makes sense to be able to pass a thread-pool as an argument (and for it to be defaulted). This coupling makes it hard to use the pool for anything else, or to use the sort with a different implementation of the thread pool.


invoke_on_destruct seems more complex than necessary. We never invoke-and-disable more than once, so we could simply replace the stored function with a no-op rather than having a boolean member:

#include <functional>
#include <utility>
// Scope guard that runs a stored action when destroyed, unless the action has
// already been run via invoke_and_disable().
class invoke_on_destruct
{
 std::function<void()> t;

public:
 explicit invoke_on_destruct(std::function<void()> t)
 : t(std::move(t))
 {}
 // Copying would run the action twice.
 invoke_on_destruct(invoke_on_destruct const &) = delete;
 invoke_on_destruct &operator=(invoke_on_destruct const &) = delete;
 ~invoke_on_destruct()
 { t(); }
 // Swaps in a no-op *before* calling the action, so an action that throws here
 // is not run a second time from the (noexcept) destructor.
 void invoke_and_disable()
 { std::exchange(t, []{})(); }
};

We don't really need invoke_and_disable(): this is the only place it's used:

invoke_on_destruct autoWaitForThread( waitForThread );
threaded_sort( itLeft, left, itSortBuf + n );
autoWaitForThread.invoke_and_disable();

That could reasonably be changed to simply reduce the scope:

{
 invoke_on_destruct autoWaitForThread( waitForThread );
 threaded_sort( itLeft, left, itSortBuf + n );
}

In any case, we need to be very careful that the invoked function does not throw, since exceptions thrown from destructors will terminate the entire program.


Please don't hide exception detail like this:

 catch( ... )
 {
 throw sort_exception();
 }

That loses useful information, and isn't even much convenience to the caller.


A few minor points of improvement:

  • There are a couple of functions that have using namespace std; where a more targeted approach (using std::move;) would be safer.

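  • return std::move(x); is an anti-pattern: it turns a candidate for copy elision into a forced move, so it can only make things slower. Just return the value and trust your compiler to elide the copy.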

  • The whole decify_string() business is a waste of your effort - just go with the user's preferred format, using std::cout.imbue(std::locale("")); early on. If you really need that specific formatting, then use a modified user locale instead.

  • Don't use all-caps names (e.g. SIZE) for identifiers - it's conventional to reserve these names for macros, and the "shoutiness" warns that they should be treated with care. That warning is diluted if we use it for non-macro names too.

  • Don't use std::endl where a flush is not required. In fact, it's a good idea to never use it, and explicitly write both \n and std::flush where you really want the effect of std::endl.

  • Fix the misleading indentation:

     for( pool_thread &pt : m_standbyThreads )
     pt.m_sortBuf.clear(),
     pt.m_sortBuf.shrink_to_fit();
     m_callerSortBuf.clear();
    
answered Nov 6, 2020 at 12:17
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.