Another look at finding the median value. This time the input is a histogram, represented either as an ordered range of (value, count) pairs or simply as a range of counts, with each value inferred from its position. I've taken care to avoid integer overflow in the computation, which unfortunately complicates the code more than I would like.
There's a simpler implementation for bidirectional ranges, and a more general version for when the total count is known (this is useful for image histograms, where the total is width ✕ height).
I have provided an adapter to ensure counts are in the range [0, +∞). This is to be applied explicitly by the user, rather than built into the functions, so that it imposes no penalty on code that doesn't use it. Other features that I might add in future, but that are not provided here:
- Find other percentiles, not just median.
- Support "wide" bins as well as single values.
- Validation adapter to check that the values are ordered.
- Validation of known-total computation to ensure that the total is correct.
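For orientation, here is a minimal sketch of the quantity being computed for a counts-only histogram, ignoring the overflow concerns that the real code addresses. The name naive_median is hypothetical and not part of the code under review. It assumes the total population fits in unsigned long long, and it rejects an all-zero histogram instead of returning a positional midpoint.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reference helper, not part of the reviewed code.
// Assumes the sum of all counts fits in unsigned long long.
inline double naive_median(std::vector<unsigned> const& counts)
{
    unsigned long long total = 0;
    for (auto c : counts) { total += c; }
    if (total == 0) {
        throw std::domain_error("no samples");
    }
    // 0-based ranks of the two middle samples (equal when total is odd).
    unsigned long long const lower_rank = (total - 1) / 2;
    unsigned long long const upper_rank = total / 2;

    unsigned long long seen = 0;
    std::size_t lower = 0;
    bool have_lower = false;
    for (std::size_t value = 0; value < counts.size(); ++value) {
        seen += counts[value];
        if (!have_lower && seen > lower_rank) {
            lower = value;          // bin holding the lower middle sample
            have_lower = true;
        }
        if (seen > upper_rank) {    // bin holding the upper middle sample
            return (static_cast<double>(lower) + static_cast<double>(value)) / 2;
        }
    }
    assert(false && "unreachable: the running sum always reaches total");
    return 0.0;
}
```

With the counts {1, 1, 0, 2, 1} the samples are 0, 1, 3, 3, 4, so the sketch returns 3, which agrees with the ones_zero_TWO_one tests.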
#include <algorithm>
#include <cmath>
#include <concepts>
#include <iterator>
#include <limits>
#include <numeric>
#include <ranges>
#include <stdexcept>
#include <tuple>
#include <type_traits>
namespace median
{
// Median result types - allow indication of integer-and-a-half
template<typename T>
struct result {};

template<std::floating_point T>
struct result<T> {
    T value;
    result(T mid) : value{mid} {}
    result(T left, T right) : value{std::midpoint(left, right)} {}
    operator T() const { return value; }
    auto as_double() const { return value; }
};
template<std::integral T>
struct result<T> {
    T whole_part;
    bool plus_half;
    result(T mid)
        : whole_part{mid},
          plus_half{false}
    {}
    result(T left, T right)
        : whole_part{std::midpoint(left, right)},
          // Compare parity as bools: for negative odd values, % 2 yields -1, not 1
          plus_half{(left % 2 != 0) != (right % 2 != 0)}
    {
        if (plus_half && right < left) { --whole_part; }
    }
    auto as_double() const { return static_cast<double>(whole_part) + 0.5 * plus_half; }
    explicit operator T() const { return whole_part; }
    explicit operator double() const { return as_double(); }
};
// Range adapters for verifying range values
auto const checked_histogram = std::views::transform([](auto&& val) {
    if constexpr (requires{ std::get<1>(val); }) {
        auto const& [i,count] = val;
        if (count < 0 || !std::isfinite(count)) {
            throw std::domain_error("invalid histogram entry");
        }
    } else {
        if (val < 0 || !std::isfinite(val)) {
            throw std::domain_error("invalid histogram entry");
        }
    }
    return val;
});
// A histogram is an ordered range of {value, count} pairs.
// Alternatively, a range of counts can be provided and the
// values 0, 1, 2, ... will be inferred.

// Median of a bidirectional histogram
template<std::ranges::bidirectional_range R>
auto from_histogram(R const& input)
    -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
    requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
{
    auto left = input.cbegin();
    auto right = input.cend();
    if (left == right) {
        throw std::domain_error("empty histogram");
    }

    auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
    auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
    auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

    auto left_sum = count(left);
    auto right_sum = count(--right);
    while (left != right) {
        // Reduce sums so that at least one of them is zero
        if (left_sum > right_sum) {
            left_sum -= right_sum;
            right_sum = 0;
        } else {
            right_sum -= left_sum;
            left_sum = 0;
        }
        // advance one of the iterators
        if (left_sum) {
            right_sum += count(--right);
        } else if (right_sum) {
            left_sum += count(++left);
        } else {
            // left and right sums both zero
            auto const it = std::find_if(std::next(left), right, has_positive_count);
            if (it == right) {
                return {value(left), value(right)};
            }
            left_sum += count(left = it);
        }
    }
    return value(left);
}
// Median of a forward-only histogram
template<std::ranges::forward_range R>
auto from_histogram(R const& input)
    -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
    requires (not std::ranges::bidirectional_range<R>)
        and requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
{
    auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

    auto left = input.cbegin();
    auto right = left;
    auto const last = input.cend();
    if (right == last) {
        throw std::domain_error("empty histogram");
    }

    auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
    auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

    auto left_sum = count(left) + 0; // addition promotes smaller types to signed/unsigned int
    auto right_sum = count(right) + 0;

    auto constexpr addition_would_overflow = [](auto augend, auto addend)
    {
        // Neither argument is negative.
        return addend > std::numeric_limits<decltype(augend)>::max() - augend;
    };

    auto constexpr reduce = [](auto& left_sum, auto &right_sum) {
        // Reduce sums
        if (left_sum > 1 && right_sum > 2) {
            auto subtrahend = std::min(left_sum - 1, (right_sum - 1) / 2);
            left_sum -= subtrahend;
            right_sum -= subtrahend * 2;
        }
    };

    using std::next;
    while (next(right) != last) {
        reduce(left_sum, right_sum);
        {
            // advance right
            auto right_addend = count(++right);
            reduce(left_sum, right_addend);
            while (addition_would_overflow(right_sum, right_addend)) {
                auto left_addend = count(++left);
                reduce(left_addend, right_addend);
                left_sum += left_addend;
                if (left == right) { break; }
            }
            right_sum += right_addend;
        }
        while (!addition_would_overflow(left_sum, left_sum) && left_sum + left_sum < right_sum) {
            // advance left until it reaches right/2
            auto left_addend = count(++left);
            reduce(left_addend, right_sum);
            while (addition_would_overflow(left_sum, left_addend)) {
                auto right_addend = count(++right);
                reduce(left_addend, right_addend);
                right_sum += right_addend;
                if (next(right) == last) {
                    break;
                }
            }
            left_sum += left_addend;
        }
    }

    if (left_sum * 2 == right_sum) {
        // tie break
        if (left == right) [[unlikely]] {
            // only happens with {0} as input
            return value(left);
        }
        auto const it = std::ranges::find_if(std::next(left), right, has_positive_count);
        return {value(left), value(it)};
    }
    return value(left);
}
template<std::ranges::forward_range R>
auto from_histogram(R const& input)
    requires std::assignable_from<double&, std::ranges::range_value_t<R>>
{
    return median::from_histogram(input | std::views::enumerate);
}
// Median of any histogram, when total population is known
template<typename T, std::ranges::forward_range R>
auto from_histogram(T const& total, R const& input)
    -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
    requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
{
    auto iter = input.cbegin();
    auto const last = input.cend();
    if (iter == last) {
        throw std::domain_error("empty histogram");
    }

    auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
    auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

    if (total == 0) [[unlikely]] {
        // Try to return something sensible
        auto const left_val = value(iter++);
        auto right_val = left_val;
        while (iter != last) {
            right_val = value(iter++);
        }
        return {left_val, right_val};
    }

    T sum = 0;
    do {
        auto addend = count(iter);
        if (addend > total - sum) {
            throw std::domain_error("overpopulated histogram");
        }
        sum += count(iter);
        if (sum > total / 2) {
            return value(iter);
        }
        if (2 * sum == total) {
            // find midpoint
            auto const left_val = value(iter);
            iter = std::ranges::find_if(++iter, last, [](auto val){ return std::get<1>(val) != 0; });
            if (iter == last) { break; }
            return {left_val, value(iter)};
        }
    } while (++iter != last);
    // ran off the end
    throw std::domain_error("underpopulated histogram");
}
template<std::ranges::input_range R, std::convertible_to<std::ranges::range_value_t<R>> T>
auto from_histogram(T total, R const& input)
{
    return median::from_histogram(total, input | std::views::enumerate);
}
}
I used these tests to write the preceding functions:
#include <gtest/gtest.h>
#include <climits>
#include <forward_list>
#include <map>
#include <vector>
TEST(vector_input, empty)
{
    std::vector<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(vector_input, one_element)
{
    std::vector<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(vector_input, negative_element)
{
    std::vector<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(vector_input, one_TWO)
{
    std::vector<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 0);
}

TEST(vector_input, all_zero)
{
    std::vector<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, all_one)
{
    std::vector<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, ones_zero_TWO_one)
{
    std::vector<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 1);
}

TEST(vector_input, ones_ZEROS_ones)
{
    std::vector<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 3);
}

TEST(vector_input, max_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(vector_input, max_MAX_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}
TEST(map_input, TEN)
{
    std::map<double,unsigned> hist{{10, 1}};
    EXPECT_EQ(median::from_histogram(hist), 10);
}

TEST(map_input, zero_TEN_FIFTEEN_twenty)
{
    std::map<double,unsigned> hist{{0, 4}, {10.5, 1}, {15, 2}, {20, 3}};
    EXPECT_EQ(median::from_histogram(hist), 12.75);
}
TEST(list_input, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(list_input, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(list_input, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(list_input, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
}

TEST(list_input, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, zero_ONES_zero)
{
    std::forward_list<int> hist{0, 1, 1, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, max_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(list_input, max_MAX_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}
TEST(list_input_with_total, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(0, empty), std::domain_error);
}

TEST(list_input_with_total, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 0);
}

TEST(list_input_with_total, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(1, hist | median::checked_histogram), std::domain_error);
}

TEST(list_input_with_total, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(3, hist).as_double(), 1);
}

TEST(list_input_with_total, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 1.5);
}

TEST(list_input_with_total, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 1.5);
}

TEST(list_input_with_total, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(5, hist).as_double(), 3);
}

TEST(list_input_with_total, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 3);
}

TEST(list_input_with_total, MAX_one)
{
    std::forward_list<unsigned int> hist{UINT_MAX-1, 1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 0);
}

TEST(list_input_with_total, one_MAX)
{
    std::forward_list<unsigned int> hist{1, UINT_MAX-1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 1);
}

TEST(list_input_with_total, max_max)
{
    // Misuse - total is not correct
    std::forward_list<unsigned int> hist{UINT_MAX/2, UINT_MAX};
    EXPECT_THROW(median::from_histogram(UINT_MAX, hist).as_double(), std::domain_error);
}
1 Answer
It looks very complicated, and it does some weird things:
Weird result types
I would expect a median function to return a simple integer or floating-point value, not a result<T>. For floating point, I see no point at all; you could just return std::midpoint(left, right) unconditionally. For integers, things are indeed a bit more complicated. There is the rounding part when one wants an integer result, I get that. But why check for right < left? Wouldn't the histograms be sorted? If not, does it even make sense to ask for a median?
It looks like a lot of work for the case where one wants a double result from a histogram of integral values. Wouldn't it be easier to just add a template parameter to from_histogram() to set the result type? So that you would write this:
std::vector<int> hist{1, 1, 1, 1};
EXPECT_EQ(median::from_histogram<double>(hist), 1.5);
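As a rough sketch of what the tie-breaking step could look like with a caller-chosen result type, assuming the two neighbouring values have already been found: the helper name midpoint_as is made up for illustration, and it assumes left <= right and that right - left does not overflow T.

```cpp
#include <type_traits>

// Hypothetical helper: midpoint of two neighbouring bin values, computed in
// a result type chosen by the caller.
// Assumes left <= right and that right - left does not overflow T.
template<typename Result, typename T>
constexpr Result midpoint_as(T left, T right)
{
    static_assert(std::is_arithmetic_v<Result> && std::is_arithmetic_v<T>,
                  "sketch covers built-in arithmetic types only");
    // Halve the distance in the result type, so a floating-point Result
    // keeps the half and an integral Result truncates it.
    return static_cast<Result>(left) + static_cast<Result>(right - left) / 2;
}
```

Here midpoint_as<double>(1, 2) gives 1.5, while midpoint_as<int>(1, 2) truncates to 1; the caller picks the behaviour through the template argument instead of unpacking a result<T>.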
Furthermore, you only allow std::integral and std::floating_point types. But what about other types that are ordered and would in principle allow a midpoint to be calculated, such as a custom fraction or bignum type?
Consider using a projection function
You have 5 overloads for from_histogram(): 2 to handle ranges with an implicit value, and 3 to handle tuples of at least 2 values, where the first tuple element is taken to be the value and the second the count. What if I have a range of 3-tuples and the count I'm interested in is in the third element? What if my value/count pairs are in a regular struct instead of in a std::tuple?
Consider allowing a projection function to be passed to from_histogram() that allows customizing how value and count information is extracted from the input range. You can provide a default function that does the equivalent of your current overloads:
template<std::ranges::forward_range R, class Proj = std::identity>
requires (/* Proj valid for R */)
auto from_histogram(R const& input, Proj proj = {}) {
    if constexpr (/* projection returns a single value */) {
        /* enumerate */
        ...
    } else if constexpr (/* projection returns a 2-tuple */) {
        /* split return value of projection into value and count */
        ...
    } else {
        /* return a compile error */
    }
}
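To make the idea concrete, here is a simplified, self-contained sketch. It is not a drop-in replacement: it uses two separate projections (one for the value, one for the count) rather than the single Proj outlined above, it sums the counts naively on the assumption that the total fits in unsigned long long, and the names bin and median_by are invented for the example.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical bin type: the count is the third member, not std::get<1>.
struct bin {
    double centre;
    char const* label;
    unsigned count;
};

// Sketch only: median of a histogram sorted by value, with caller-supplied
// projections. Assumes values convert to double and that the total
// population fits in unsigned long long.
template<typename Range, typename ValueProj, typename CountProj>
double median_by(Range const& input, ValueProj value_of, CountProj count_of)
{
    unsigned long long total = 0;
    for (auto const& item : input) {
        total += std::invoke(count_of, item);
    }
    if (total == 0) {
        throw std::domain_error("no samples");
    }
    // 0-based ranks of the two middle samples (equal when total is odd).
    unsigned long long const lower_rank = (total - 1) / 2;
    unsigned long long const upper_rank = total / 2;

    unsigned long long seen = 0;
    double lower = 0;
    bool have_lower = false;
    for (auto const& item : input) {
        seen += std::invoke(count_of, item);
        double const value = std::invoke(value_of, item);
        if (!have_lower && seen > lower_rank) {
            lower = value;
            have_lower = true;
        }
        if (seen > upper_rank) {
            return lower + (value - lower) / 2;
        }
    }
    assert(false && "unreachable: seen always reaches total");
    return lower;
}
```

Because std::invoke accepts pointers to data members, a call can be as short as median_by(hist, &bin::centre, &bin::count); lambdas work just as well for tuples or computed values.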
Overflow handling
I question any code which would produce a histogram whose counts sum to more than can be stored in the count type; it would likely be buggy itself, as a different distribution of values could then already cause an overflow. But your algorithm can handle that, which is good.
I wonder if left_sum * 2 could wrap? If so, your algorithm could return an incorrect result.
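I haven't verified whether left_sum can actually get large enough for that to happen. If it can, one way to sidestep the question is to compare against half of the other operand instead of doubling. A minimal sketch, with a made-up helper name (is_exactly_half) and assuming non-negative integer counts:

```cpp
#include <limits>
#include <type_traits>

// Hypothetical helper: reports whether 2 * half == whole without ever
// forming 2 * half, so nothing can wrap (unsigned) or overflow (signed).
// Assumes both arguments are non-negative.
template<typename T>
constexpr bool is_exactly_half(T half, T whole)
{
    static_assert(std::is_integral_v<T>, "integer counts only");
    return whole % 2 == 0 && half == whole / 2;
}

// Demonstration value: 2^(N-1) + 1 for an N-bit unsigned, so doubling wraps to 2.
constexpr unsigned wrapping_half = std::numeric_limits<unsigned>::max() / 2 + 2;
static_assert(wrapping_half * 2u == 2u, "unsigned arithmetic wraps modulo 2^N");
static_assert(!is_exactly_half(wrapping_half, 2u), "the division form is not fooled");
```

For signed count types the doubled form would be undefined behaviour on overflow rather than a wrap, which is another reason to prefer the division form.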