I just started learning C++, and I'm taking an algorithms course on MIT OpenCourseWare to get a better understanding of basic CS and of best practices in C++. Any comments on the structure of the algorithm (optimizations, etc.) or on C++ best practices would be greatly appreciated!
#include <iterator>
#include <random>
namespace sort
{
    namespace detail
    {
        template <typename ForwardIter, typename Compare>
        ForwardIter randomized_partition(ForwardIter start, ForwardIter end, Compare less_than)
        {
            static std::default_random_engine gen;
            std::uniform_int_distribution<long> rand_index(0, std::distance(start, end) - 1);
            ForwardIter pivot = start + rand_index(gen);
            std::iter_swap(start, pivot);
            pivot = start;
            ForwardIter less_than_pivot = start, greater_than_pivot = start + 1;
            while (greater_than_pivot != end)
            {
                if ( less_than(*greater_than_pivot, *pivot) ) {
                    ++less_than_pivot;
                    std::iter_swap(less_than_pivot, greater_than_pivot);
                }
                ++greater_than_pivot;
            }
            std::iter_swap(pivot, less_than_pivot);
            pivot = less_than_pivot;
            return pivot;
        }
    }
    template <typename ForwardIter, typename Compare = std::less<typename std::iterator_traits<ForwardIter>::value_type>>
    void randomized_quick_sort(ForwardIter start, ForwardIter end, Compare less_than = Compare())
    {
        if (std::distance(start, end) > 1)
        {
            ForwardIter pivot = sort::detail::randomized_partition(start, end, less_than);
            randomized_quick_sort(start, pivot, less_than);
            randomized_quick_sort(pivot + 1, end, less_than);
        }
    }
}
1 Answer
There's an inefficiency here:
if (std::distance(start, end) > 1)
For a forward iterator, std::distance() may be linear in the distance. Since we only care whether the distance is 0 or 1, we can test for those two cases:
if (start == end or std::next(start) == end) {
    // 0 or 1 elements - must be sorted!
    return;
}
Note that real quicksort implementations switch to a different algorithm such as selection sort when the number of elements is low - but we could still write a bounded distance check (in the same way that some platforms have a strnlen() function).
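// untested
template <std::forward_iterator ForwardIter>
bool distance_le(ForwardIter start,
                 std::sentinel_for<ForwardIter> auto end,
                 std::size_t limit)
{
    // Walk at most `limit` steps, stopping early if we reach the end.
    while (limit-- != 0) {
        if (start == end) {
            return true;
        }
        ++start;
    }
    // After `limit` steps, the distance is <= limit only if we are now at the end.
    return start == end;
}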
We have another use of std::distance() in randomized_partition(). If we want to eliminate that, we'll want to measure distance in the outermost function, then pass the correct distances around to its callees.
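One way to pass the distances around might look like the sketch below. It is not the original code: the names quick_sort, quick_sort_n, partition_n and demo_forward_list are hypothetical, and it assumes C++17 for structured bindings. The range length is measured once in the outermost function, and the partition step returns how many elements ended up before the pivot, so neither callee needs std::distance().

```cpp
#include <algorithm>     // std::iter_swap, std::is_sorted
#include <cassert>
#include <cstddef>
#include <forward_list>
#include <functional>
#include <iterator>
#include <random>
#include <utility>

namespace sketch {

// Partitions the `len` elements starting at `start` around a random pivot.
// Returns the pivot's final position and the number of elements before it.
template <typename ForwardIter, typename Compare>
std::pair<ForwardIter, std::size_t>
partition_n(ForwardIter start, std::size_t len, Compare less_than)
{
    assert(len > 0);  // callers never pass an empty range
    using diff_t = typename std::iterator_traits<ForwardIter>::difference_type;
    static std::default_random_engine gen;
    std::uniform_int_distribution<std::size_t> rand_index(0, len - 1);
    // Move the randomly chosen pivot to the front of the range.
    std::iter_swap(start, std::next(start, static_cast<diff_t>(rand_index(gen))));

    ForwardIter boundary = start;  // last element known to be < pivot
    std::size_t smaller = 0;       // how many elements are < pivot
    ForwardIter current = std::next(start);
    for (std::size_t i = 1; i < len; ++i, ++current) {
        if (less_than(*current, *start)) {
            ++boundary;
            ++smaller;
            std::iter_swap(boundary, current);
        }
    }
    std::iter_swap(start, boundary);  // put the pivot in its final place
    return {boundary, smaller};
}

// Sorts `len` elements starting at `start`; lengths are passed, not measured.
template <typename ForwardIter, typename Compare>
void quick_sort_n(ForwardIter start, std::size_t len, Compare less_than)
{
    if (len < 2) {
        return;  // 0 or 1 elements are already sorted
    }
    auto [pivot, smaller] = partition_n(start, len, less_than);
    quick_sort_n(start, smaller, less_than);
    quick_sort_n(std::next(pivot), len - smaller - 1, less_than);
}

// Outermost function: the only place where the range is measured.
template <typename ForwardIter, typename Compare = std::less<>>
void quick_sort(ForwardIter start, ForwardIter end, Compare less_than = {})
{
    auto const len = static_cast<std::size_t>(std::distance(start, end));
    quick_sort_n(start, len, less_than);
}

// Example use on a container that only provides forward iterators.
inline std::forward_list<int> demo_forward_list()
{
    std::forward_list<int> values{5, 3, 9, 1, 5, 7, 2};
    quick_sort(values.begin(), values.end());
    return values;
}

}  // namespace sketch
```

The walk to the random pivot is still linear for a forward iterator, but that cost is of the same order as the partition pass that follows it.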
ForwardIter pivot = start + rand_index(gen);
We can't assume operator+ is defined for a forward iterator:
ForwardIter pivot = start;
std::advance(pivot, rand_index(gen));
rand_index() generates long, but we really need it to generate std::size_t:
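auto const len = std::distance(start, end);
assert(len > 0); // randomized_quick_sort() never calls us with an empty range
// 0uz is the C++23 std::size_t literal; both arguments must have the same type for deduction
std::uniform_int_distribution rand_index{0uz, static_cast<std::size_t>(len - 1)};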
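Putting the last two points together, pivot selection could be factored into a small helper along these lines. This is only a sketch: choose_pivot is a hypothetical name, and passing the length and the generator in as parameters is an assumption made for illustration rather than something taken from the code under review. Spelling out the template argument also avoids relying on a C++23 literal suffix.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <random>

// Hypothetical helper: returns an iterator to a uniformly chosen element
// among the `len` elements starting at `start`. Only forward-iterator
// operations are used, so reaching the chosen element takes linear time.
template <typename ForwardIter, typename Generator>
ForwardIter choose_pivot(ForwardIter start, std::size_t len, Generator& gen)
{
    assert(len > 0);  // an empty range has no valid pivot
    using diff_t = typename std::iterator_traits<ForwardIter>::difference_type;
    // Explicit std::size_t distribution over the closed range [0, len - 1].
    std::uniform_int_distribution<std::size_t> rand_index(0, len - 1);
    return std::next(start, static_cast<diff_t>(rand_index(gen)));
}
```

A caller would then write something like std::iter_swap(start, choose_pivot(start, len, gen)) before partitioning.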
The second recursive call here is a tail call:
randomized_quick_sort(start, pivot, less_than);
randomized_quick_sort(pivot + 1, end, less_than);
Because C++ compilers are not required to perform tail call elimination, it may be wise to transform this to iterative form. Sort the smaller partition recursively (to minimise call-stack depth), then adjust start or end and loop back to the beginning.
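A minimal sketch of that transformation follows. To keep the size comparison cheap it assumes random-access iterators, which is a simplification compared with the forward iterators in the question, and the names partition_random, quick_sort_iterative and sorted_example are hypothetical. Because the recursive call always handles the partition that is no larger than the other, the recursion depth stays logarithmic in the number of elements.

```cpp
#include <algorithm>   // std::iter_swap, std::is_sorted
#include <cassert>
#include <functional>
#include <iterator>
#include <random>
#include <vector>

namespace sketch {

// Lomuto-style partition around a randomly chosen pivot; returns the pivot's
// final position. Assumes a non-empty range of random-access iterators.
template <typename RandomIter, typename Compare>
RandomIter partition_random(RandomIter start, RandomIter end, Compare comp)
{
    assert(start != end);
    using diff_t = typename std::iterator_traits<RandomIter>::difference_type;
    static std::default_random_engine gen;
    std::uniform_int_distribution<diff_t> pick(0, (end - start) - 1);
    std::iter_swap(start, start + pick(gen));  // pivot now sits at the front

    RandomIter boundary = start;  // last element known to be < pivot
    for (RandomIter current = start + 1; current != end; ++current) {
        if (comp(*current, *start)) {
            ++boundary;
            std::iter_swap(boundary, current);
        }
    }
    std::iter_swap(start, boundary);
    return boundary;
}

template <typename RandomIter, typename Compare = std::less<>>
void quick_sort_iterative(RandomIter start, RandomIter end, Compare comp = {})
{
    while (end - start > 1) {
        RandomIter pivot = partition_random(start, end, comp);
        if (pivot - start < end - (pivot + 1)) {
            quick_sort_iterative(start, pivot, comp);    // smaller side: recurse
            start = pivot + 1;                           // larger side: loop
        } else {
            quick_sort_iterative(pivot + 1, end, comp);  // smaller side: recurse
            end = pivot;                                 // larger side: loop
        }
    }
}

// Example use with a std::vector.
inline std::vector<int> sorted_example()
{
    std::vector<int> values{9, 2, 7, 2, 5, 0, 8, 1};
    quick_sort_iterative(values.begin(), values.end());
    return values;
}

}  // namespace sketch
```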
- "Because C++ compilers are not required to perform tail call elimination," But they all do, so I wouldn't make the source code more complex by making it iterative. – G. Sliepen, Sep 1, 2024 at 18:06
- Still a good idea to do the smallest sort in the recursive call and the larger one in the tail call, even when tail-call elimination happens. – Toby Speight, Sep 1, 2024 at 19:07
- […] less_than since people might want to sort by a different key. I think std::sort uses comp as its parameter name.
- […] operator+, so start + 1 should be replaced with std::next(start). The same goes for start + rand_index(gen);.
- […] static_assert to ensure the passed iterator's category inherited from std::forward_iterator_tag. Anything else that I could improve on?