Skip to main content
Code Review

Return to Question

added 49 characters in body
Source Link
Toby Speight
  • 87.9k
  • 14
  • 104
  • 325
  • Find other percentiles, not just median.
  • Support "wide" bins as well as single values.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
  • Find other percentiles, not just median.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
  • Find other percentiles, not just median.
  • Support "wide" bins as well as single values.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
Source Link
Toby Speight
  • 87.9k
  • 14
  • 104
  • 325

Compute median of a histogram

Another look at finding median value. This time the input is a histogram represented as an ordered range of (value, count) pairs, or simply as a range of counts, with value inferred from position. I've taken care to avoid integer overflow in the computation, which unfortunately complicates the code more than I would like.

There's a simpler implementation for bidirectional ranges, and a more general version for when the total count is known (this is useful for image histograms, where the total is width ✕ height).

I have provided an adapter to ensure counts are in the range [0, +∞). This is to be explicitly applied by user, rather than built into the functions, so that it imposes no penalty on code that doesn't use it. Other features I might add in future, but not provided here, are:

  • Find other percentiles, not just median.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
#include <cmath>
#include <concepts>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <ranges>
#include <tuple>
namespace median
{
 // Median result types - allow indication of integer-and-a-half
 template<typename T>
 struct result {};
 template<std::floating_point T>
 struct result<T> {
 T value;
 result(T mid) : value{mid} {}
 result(T left, T right) : value{std::midpoint(left, right)} {}
 operator T() const { return value; }
 auto as_double() const { return value; }
 };
 template<std::integral T>
 struct result<T> {
 T whole_part;
 bool plus_half;
 result(T mid)
 : whole_part{mid},
 plus_half{false}
 {}
 result(T left, T right)
 : whole_part{std::midpoint(left, right)},
 plus_half{left % 2 != right % 2}
 {
 if (plus_half && right < left) { --whole_part; }
 }
 auto as_double() const { return static_cast<double>(whole_part) + 0.5 * plus_half; }
 explicit operator T() const { return whole_part; }
 explicit operator double() const { return as_double(); }
 };
 // Range adapters for verifying range values
 auto const checked_histogram = std::views::transform([](auto&& val) {
 if constexpr (requires{ std::get<1>(val); }) {
 auto const& [i,count] = val;
 if (count < 0 || !std::isfinite(count)) {
 throw std::domain_error("invalid histogram entry");
 }
 } else {
 if (val < 0 || !std::isfinite(val)) {
 throw std::domain_error("invalid histogram entry");
 }
 }
 return val;
 });
 // A histogram is an ordered range of {value, count} pairs.
 // Alternatively, a range of counts can be provided and the
 // values 0, 1, 2, ... will be inferred.
 // Median of a bidirectional histogram
 template<std::ranges::bidirectional_range R>
 auto from_histogram(R const& input)
 -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
 requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
 {
 auto left = input.cbegin();
 auto right = input.cend();
 if (left == right) {
 throw std::domain_error("empty histogram");
 }
 auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
 auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
 auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };
 auto left_sum = count(left);
 auto right_sum = count(--right);
 while (left != right) {
 // Reduce sums so that at least one of them is zero
 if (left_sum > right_sum) {
 left_sum -= right_sum;
 right_sum = 0;
 } else {
 right_sum -= left_sum;
 left_sum = 0;
 }
 // advance one of the iterators
 if (left_sum) {
 right_sum += count(--right);
 } else if (right_sum) {
 left_sum += count(++left);
 } else {
 // left and right sums both zero
 auto const it = std::find_if(std::next(left), right, has_positive_count);
 if (it == right) {
 return {value(left), value(right)};
 }
 left_sum += count(left = it);
 }
 }
 return value(left);
 }
 // Median of a forward-only histogram
 template<std::ranges::forward_range R>
 auto from_histogram(R const& input)
 -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
 requires (not std::ranges::bidirectional_range<R>)
 and requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
 {
 auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };
 auto left = input.cbegin();
 auto right = left;
 auto const last = input.cend();
 if (right == last) {
 throw std::domain_error("empty histogram");
 }
 auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
 auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
 auto left_sum = count(left) + 0; // addition promotes smaller types to signed/unsigned int
 auto right_sum = count(right) + 0;
 auto constexpr addition_would_overflow = [](auto augend, auto addend)
 {
 // Neither argument is negative.
 return addend > std::numeric_limits<decltype(augend)>::max() - augend;
 };
 auto constexpr reduce = [](auto& left_sum, auto &right_sum) {
 // // Reduce sums
 if (left_sum > 1 && right_sum > 2) {
 auto subtrahend = std::min(left_sum - 1, (right_sum - 1) / 2);
 left_sum -= subtrahend;
 right_sum -= subtrahend * 2;
 }
 };
 using std::next;
 while (next(right) != last) {
 reduce(left_sum, right_sum);
 {
 // advance right
 auto right_addend = count(++right);
 reduce(left_sum, right_addend);
 while (addition_would_overflow(right_sum, right_addend)) {
 auto left_addend = count(++left);
 reduce(left_addend, right_addend);
 left_sum += left_addend;
 if (left == right) { break; }
 }
 right_sum += right_addend;
 }
 while (!addition_would_overflow(left_sum, left_sum) && left_sum + left_sum < right_sum) {
 // advance left until it reaches right/2
 auto left_addend = count(++left);
 reduce(left_addend, right_sum);
 while (addition_would_overflow(left_sum, left_addend)) {
 auto right_addend = count(++right);
 reduce(left_addend, right_addend);
 right_sum += right_addend;
 if (next(right) == last) {
 break;
 }
 }
 left_sum += left_addend;
 }
 }
 if (left_sum * 2 == right_sum) {
 // tie break
 if (left == right) [[unlikely]] {
 // only happens with {0} as input
 return value(left);
 }
 auto const it = std::ranges::find_if(std::next(left), right, has_positive_count);
 return {value(left), value(it)};
 }
 return value(left);
 }
 template<std::ranges::forward_range R>
 auto from_histogram(R const& input)
 requires std::assignable_from<double&, std::ranges::range_value_t<R>>
 {
 return median::from_histogram(input | std::views::enumerate);
 }
 // Median of any histogram, when total population is known
 template<typename T, std::ranges::forward_range R>
 auto from_histogram(T const& total, R const& input)
 -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
 requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
 {
 auto iter = input.cbegin();
 auto const last = input.cend();
 if (iter == last) {
 throw std::domain_error("empty histogram");
 }
 auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
 auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
 if (total == 0) [[unlikely]] {
 // Try to return something sensible
 auto const left_val = value(iter++);
 auto right_val = left_val;
 while (iter != last) {
 right_val = value(iter++);
 }
 return {left_val, right_val};
 }
 T sum = 0;
 do {
 auto addend = count(iter);
 if (addend > total - sum) {
 throw std::domain_error("overpopulated histogram");
 }
 sum += count(iter);
 if (sum > total / 2) {
 return value(iter);
 }
 if (2 * sum == total) {
 // find midpoint
 auto const left_val = value(iter);
 iter = std::ranges::find_if(++iter, last, [](auto val){ return std::get<1>(val) != 0; });
 if (iter == last) { break; }
 return {left_val, value(iter)};
 }
 } while (++iter != last);
 // ran off the end
 throw std::domain_error("underpopulated histogram");
 }
 template<std::ranges::input_range R, std::convertible_to<std::ranges::range_value_t<R>> T>
 auto from_histogram(T total, R const& input)
 {
 return median::from_histogram(total, input | std::views::enumerate);
 }
}

I used these tests to write the preceding functions:

#include <gtest/gtest.h>
#include <climits>
#include <forward_list>
#include <map>
#include <vector>
TEST(vector_input, empty)
{
 std::vector<int> empty;
 EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}
TEST(vector_input, one_element)
{
 std::vector<int> hist{0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}
TEST(vector_input, negative_element)
{
 std::vector<int> hist{-1};
 EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}
TEST(vector_input, one_TWO)
{
 std::vector<int> hist{1,2};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
 EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 0);
}
TEST(vector_input, all_zero)
{
 std::vector<int> hist{0, 0, 0, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}
TEST(vector_input, all_one)
{
 std::vector<int> hist{1, 1, 1, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}
TEST(vector_input, ones_zero_TWO_one)
{
 std::vector<int> hist{1, 1, 0, 2, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
 EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 1);
}
TEST(vector_input, ones_ZEROS_ones)
{
 std::vector<int> hist{1, 1, 0, 0, 0, 1, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
 EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 3);
}
TEST(vector_input, max_MAX_max)
{
 std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}
TEST(vector_input, max_MAX_MAX_max)
{
 std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}
TEST(map_input, TEN)
{
 std::map<double,unsigned> hist{{10, 1}};
 EXPECT_EQ(median::from_histogram(hist), 10);
}
TEST(map_input, zero_TEN_FIFTEEN_twenty)
{
 std::map<double,unsigned> hist{{0, 4}, {10.5, 1}, {15, 2}, {20, 3}};
 EXPECT_EQ(median::from_histogram(hist), 12.75);
}
TEST(list_input, empty)
{
 std::forward_list<int> empty;
 EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}
TEST(list_input, one_element)
{
 std::forward_list<int> hist{0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}
TEST(list_input, negative_element)
{
 std::forward_list<int> hist{-1};
 EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}
TEST(list_input, one_TWO)
{
 std::forward_list<int> hist{1,2};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
}
TEST(list_input, all_zero)
{
 std::forward_list<int> hist{0, 0, 0, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}
TEST(list_input, all_one)
{
 std::forward_list<int> hist{1, 1, 1, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}
TEST(list_input, zero_ONES_zero)
{
 std::forward_list<int> hist{0, 1, 1, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}
TEST(list_input, ones_zero_TWO_one)
{
 std::forward_list<int> hist{1, 1, 0, 2, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}
TEST(list_input, ones_ZEROS_ones)
{
 std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}
TEST(list_input, max_MAX_max)
{
 std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}
TEST(list_input, max_MAX_MAX_max)
{
 std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
 EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}
TEST(list_input_with_total, empty)
{
 std::forward_list<int> empty;
 EXPECT_THROW(median::from_histogram(0, empty), std::domain_error);
}
TEST(list_input_with_total, one_element)
{
 std::forward_list<int> hist{0};
 EXPECT_EQ(median::from_histogram(0, hist).as_double(), 0);
}
TEST(list_input_with_total, negative_element)
{
 std::forward_list<int> hist{-1};
 EXPECT_THROW(median::from_histogram(1, hist | median::checked_histogram), std::domain_error);
}
TEST(list_input_with_total, one_TWO)
{
 std::forward_list<int> hist{1,2};
 EXPECT_EQ(median::from_histogram(3, hist).as_double(), 1);
}
TEST(list_input_with_total, all_zero)
{
 std::forward_list<int> hist{0, 0, 0, 0};
 EXPECT_EQ(median::from_histogram(0, hist).as_double(), 1.5);
}
TEST(list_input_with_total, all_one)
{
 std::forward_list<int> hist{1, 1, 1, 1};
 EXPECT_EQ(median::from_histogram(4, hist).as_double(), 1.5);
}
TEST(list_input_with_total, ones_zero_TWO_one)
{
 std::forward_list<int> hist{1, 1, 0, 2, 1};
 EXPECT_EQ(median::from_histogram(5, hist).as_double(), 3);
}
TEST(list_input_with_total, ones_ZEROS_ones)
{
 std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
 EXPECT_EQ(median::from_histogram(4, hist).as_double(), 3);
}
TEST(list_input_with_total, MAX_one)
{
 std::forward_list<unsigned int> hist{UINT_MAX-1, 1};
 EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 0);
}
TEST(list_input_with_total, one_MAX)
{
 std::forward_list<unsigned int> hist{1, UINT_MAX-1};
 EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 1);
}
TEST(list_input_with_total, max_max)
{
 // Misuse - total is not correct
 std::forward_list<unsigned int> hist{UINT_MAX/2, UINT_MAX};
 EXPECT_THROW(median::from_histogram(UINT_MAX, hist).as_double(), std::domain_error);
}
```
lang-cpp

AltStyle によって変換されたページ (->オリジナル) /