Recently I found that binary search (std::ranges::lower_bound and std::ranges::upper_bound) is the main bottleneck in my library.
So I wanted to improve it; one attempt was to use SIMD to replace the C++ standard library binary search. In short, the result was actually a little bit slower than the original binary search.
But I don't want to just discard it, so I am posting it here.
#include <algorithm>
#include <concepts>
#include <cstdint>
#include <immintrin.h>  // AVX-512 intrinsics and the __m512* register types
template <typename K>
concept CanUseSimd = (sizeof(K) == 4 || sizeof(K) == 8) && (std::signed_integral<K> || std::floating_point<K>);
using regi = __m512i;
using regf = __m512;
using regd = __m512d;
// Bit i is set when key > keys[i], for the 16 (or 8) keys in one 512-bit block.
unsigned int cmp(std::int32_t key, const std::int32_t* key_ptr) {
    regi key_broadcasted = _mm512_set1_epi32(key);
    regi keys_to_comp = _mm512_load_si512(reinterpret_cast<const regi*>(key_ptr));
    return _mm512_cmpgt_epi32_mask(key_broadcasted, keys_to_comp);
}
unsigned int cmp(std::int64_t key, const std::int64_t* key_ptr) {
    regi key_broadcasted = _mm512_set1_epi64(key);
    regi keys_to_comp = _mm512_load_si512(reinterpret_cast<const regi*>(key_ptr));
    return _mm512_cmpgt_epi64_mask(key_broadcasted, keys_to_comp);
}
unsigned int cmp(float key, const float* key_ptr) {
    regf key_broadcasted = _mm512_set1_ps(key);
    regf keys_to_comp = _mm512_load_ps(key_ptr);
    return _mm512_cmp_ps_mask(key_broadcasted, keys_to_comp, _CMP_GT_OQ);  // ordered, quiet greater-than
}
unsigned int cmp(double key, const double* key_ptr) {
    regd key_broadcasted = _mm512_set1_pd(key);
    regd keys_to_comp = _mm512_load_pd(key_ptr);
    return _mm512_cmp_pd_mask(key_broadcasted, keys_to_comp, _CMP_GT_OQ);
}
// Bit i is set when keys[i] > key.
unsigned int cmp(const std::int32_t* key_ptr, std::int32_t key) {
    regi key_broadcasted = _mm512_set1_epi32(key);
    regi keys_to_comp = _mm512_load_si512(reinterpret_cast<const regi*>(key_ptr));
    return _mm512_cmpgt_epi32_mask(keys_to_comp, key_broadcasted);
}
unsigned int cmp(const std::int64_t* key_ptr, std::int64_t key) {
    regi key_broadcasted = _mm512_set1_epi64(key);
    regi keys_to_comp = _mm512_load_si512(reinterpret_cast<const regi*>(key_ptr));
    return _mm512_cmpgt_epi64_mask(keys_to_comp, key_broadcasted);
}
unsigned int cmp(const float* key_ptr, float key) {
    regf key_broadcasted = _mm512_set1_ps(key);
    regf keys_to_comp = _mm512_load_ps(key_ptr);
    return _mm512_cmp_ps_mask(keys_to_comp, key_broadcasted, _CMP_GT_OQ);
}
unsigned int cmp(const double* key_ptr, double key) {
    regd key_broadcasted = _mm512_set1_pd(key);
    regd keys_to_comp = _mm512_load_pd(key_ptr);
    return _mm512_cmp_pd_mask(keys_to_comp, key_broadcasted, _CMP_GT_OQ);
}
template <CanUseSimd K>
struct SimdTrait {
    static constexpr int shift = (sizeof(K) == 4) ? 4 : 3;    // log2 of lanes per 512-bit register
    static constexpr int mask = (sizeof(K) == 4) ? 0xF : 0x7;
    static constexpr int unit = (sizeof(K) == 4) ? 16 : 8;    // lanes per 512-bit register
};
template <CanUseSimd K, bool less>
inline std::int32_t get_lb_simd(K key, const K* first, const K* last) {
    auto len = static_cast<std::int32_t>(last - first);
    // Round len up to the smallest multiple of the SIMD unit that is >= len.
    len = ((len >> SimdTrait<K>::shift) + ((len & SimdTrait<K>::mask) ? 1 : 0)) << SimdTrait<K>::shift;
    const K* curr = first;
    std::int32_t i = 0;
    int mask = 0;
    auto half = (len >> (SimdTrait<K>::shift + 1)) << SimdTrait<K>::shift;
    while (len > SimdTrait<K>::unit) {
        len -= half;
        auto next_half = (len >> (SimdTrait<K>::shift + 1)) << SimdTrait<K>::shift;
        // Prefetch both blocks that could be probed next.
        __builtin_prefetch(curr + next_half - SimdTrait<K>::unit);
        __builtin_prefetch(curr + half + next_half - SimdTrait<K>::unit);
        // Probe the SIMD block that ends at curr + half.
        auto mid = curr + half - SimdTrait<K>::unit;
        if constexpr (less) {
            mask = ~cmp(key, mid);
        } else {
            mask = ~cmp(mid, key);
        }
        i = __builtin_ffs(mask) - 1;
        curr += (i == SimdTrait<K>::unit) * half;
        if (i & SimdTrait<K>::mask) {
            return static_cast<std::int32_t>(mid - first) + i;
        }
        half = next_half;
    }
    if constexpr (less) {
        mask = ~cmp(key, curr);
    } else {
        mask = ~cmp(curr, key);
    }
    i = __builtin_ffs(mask) - 1;
    return std::min(static_cast<std::int32_t>(last - first), static_cast<std::int32_t>(curr - first) + i);
}
template <CanUseSimd K, bool less>
inline std::int32_t get_ub_simd(K key, const K* first, const K* last) {
    auto len = static_cast<std::int32_t>(last - first);
    // Round len up to the smallest multiple of the SIMD unit that is >= len.
    len = ((len >> SimdTrait<K>::shift) + ((len & SimdTrait<K>::mask) ? 1 : 0)) << SimdTrait<K>::shift;
    const K* curr = first;
    std::int32_t i = 0;
    int mask = 0;
    while (len > SimdTrait<K>::unit) {
        auto half = (len >> (SimdTrait<K>::shift + 1)) << SimdTrait<K>::shift;
        len -= half;
        // Probe the SIMD block that ends at curr + half.
        auto mid = curr + half - SimdTrait<K>::unit;
        if constexpr (less) {
            mask = cmp(mid, key);
        } else {
            mask = cmp(key, mid);
        }
        i = __builtin_ffs(mask) - 1;
        curr += (mask == 0) * half;
        if (i > 0) {
            return static_cast<std::int32_t>(mid - first) + i;
        }
    }
    if constexpr (less) {
        mask = cmp(curr, key);
    } else {
        mask = cmp(key, curr);
    }
    i = (mask == 0) ? len : __builtin_ffs(mask) - 1;
    return std::min(static_cast<std::int32_t>(last - first), static_cast<std::int32_t>(curr - first) + i);
}
1 Answer
Use the same API as std::ranges::lower_bound
You mention that you wanted to use std::ranges algorithms to do a binary search. If you implement your own functions, try to keep the same interface as the standard library does. This allows your functions to be used as drop-in replacements for the STL ones, making porting the code easier.
The STL algorithms either get passed two iterators or a range. You can use the std::contiguous_iterator concept to restrict those iterators to ones that are suitable for SIMD operations.
I would also make sure your function names match those of the STL, but put them in your own namespace. This way, you can write:
std::vector<int> data;
...
auto result = simd::lower_bound(data.begin(), data.end(), 42);
And if you were using using namespace std::ranges before, you can simply add or change that to using namespace simd without having to change any calls to lower_bound(). Even better: since you are using concepts, this can be made to transparently fall back to the std::ranges versions for types that are not supported by your SIMD versions.
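A minimal sketch of how that fallback could look, reusing the question's CanUseSimd concept (the overloads and constraints here are my own illustration, not a finished implementation):
#include <algorithm>
#include <concepts>
#include <iterator>

// CanUseSimd as defined in the question.
template <typename K>
concept CanUseSimd = (sizeof(K) == 4 || sizeof(K) == 8) &&
                     (std::signed_integral<K> || std::floating_point<K>);

namespace simd {
    // SIMD path: contiguous storage and a supported key type (body elided here).
    template <std::contiguous_iterator I, std::sentinel_for<I> S, class T>
        requires CanUseSimd<std::iter_value_t<I>>
    I lower_bound(I first, S last, const T& value);

    // Everything else falls straight through to the standard algorithm,
    // so call sites can use simd::lower_bound unconditionally.
    template <std::forward_iterator I, std::sentinel_for<I> S, class T>
        requires (!std::contiguous_iterator<I> || !CanUseSimd<std::iter_value_t<I>>)
    I lower_bound(I first, S last, const T& value) {
        return std::ranges::lower_bound(first, last, value);
    }
}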
You'll also notice that the STL either returns iterators or std::size_t values. Don't narrow lengths and indices to 32-bit types such as std::int32_t; it will not give you any speedup, and it can produce incorrect results if someone runs the code on arrays with more elements than a 32-bit index can represent.
Most STL algorithms that need to compare elements of a container use std::less by default, and allow you to pass a custom comparator. Again, just follow this pattern, but make SIMD versions, so that your lower_bound() looks like:
namespace simd {
    template<CanUseSimd T>
    struct less;

    template<CanUseSimd T>
    struct greater;

    template<>
    struct less<std::int32_t> {
        unsigned int operator()(const std::int32_t* data, std::int32_t key) const {
            regi keys = _mm512_set1_epi32(key);
            regi values = _mm512_load_si512(reinterpret_cast<const regi*>(data));
            return _mm512_cmplt_epi32_mask(keys, values);
        }
    };

    template<>
    struct greater<std::int32_t> {
        unsigned int operator()(const std::int32_t* data, std::int32_t key) const {
            ...
            return _mm512_cmpgt_epi32_mask(keys, values);
        }
    };

    ...

    template<std::contiguous_iterator I, std::sentinel_for<I> S, class T,
             class Comp = less<std::iter_value_t<I>>>
        requires CanUseSimd<std::iter_value_t<I>>
    I lower_bound(I first, S last, const T& value, Comp comp = {});
}
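Assuming the sketch above were filled in, a call with a custom comparator would then read just like the standard library version (hypothetical usage, not code from the review):
std::vector<std::int32_t> data = {9, 7, 5, 3, 1};  // sorted descending
auto it = simd::lower_bound(data.begin(), data.end(), 4,
                            simd::greater<std::int32_t>{});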
Use the C++ standard library where possible
The standard library has lots of functionality, and with every new release of the standard more is added. Nowadays it is rare that you have to fall back to C functions or compiler builtins. Consider __builtin_ffs(): since C++20, you can use std::countr_zero() instead.
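For instance, the __builtin_ffs(mask) - 1 idiom from the question maps onto std::countr_zero() like this; note that the two differ for an all-zero mask, which the surrounding code has to handle either way:
#include <bit>      // std::countr_zero (C++20)
#include <cstdint>

// Index of the lowest set bit. For m != 0 this equals __builtin_ffs(m) - 1;
// for m == 0, __builtin_ffs(m) - 1 yields -1 while std::countr_zero(m) yields 32.
int first_set_bit(std::uint32_t m) {
    return std::countr_zero(m);
}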
To find the halfway point between two values, or even two pointers, there's std::midpoint(). Consider creating your own SIMD version of it:
template<CanUseSimd K>
const K* midpoint(const K* a, const K* b) {
    auto len = b - a;
    auto half = (len / 2) & ~SimdTrait<K>::mask;  // round down to a whole SIMD block
    return a + half;
}
Prefetching might not be worth it
I notice you are using __builtin_prefetch() in get_lb_simd(), but not in get_ub_simd(). You are also fetching two cache lines: one for the lower half and one for the upper half, but you'll only use one of those, so you are using twice as much memory bandwidth as necessary. The latency of the AVX-512 mask compare instructions is only 3 cycles according to the Intel® Intrinsics Guide, so while prefetching the right cache line might hide a little bit of latency, prefetching the wrong one as well will probably negate any benefit. Then again, it heavily depends on the microarchitecture and the memory system, so you will have to benchmark it with and without prefetching to see what the impact is.
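One low-effort way to run that benchmark is to make the prefetch a compile-time switch, for example (the helper name and template parameter below are my own, not part of the reviewed code):
// Sketch: compile the search with and without prefetching from a single source.
template <bool Prefetch, typename K>
inline void maybe_prefetch(const K* p) {
    if constexpr (Prefetch) {
        __builtin_prefetch(p);  // GCC/Clang builtin; compiled out when Prefetch is false
    }
}
get_lb_simd() could then take an extra bool Prefetch template parameter and call maybe_prefetch<Prefetch>(...) instead of __builtin_prefetch() directly, so both variants stay in sync.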
Consider using a SIMD library
Using intrinsics has some drawbacks: apart from making it hard to port your code to other CPU architectures like Arm, you have to write a lot of code yourself. Consider using a C++ library that provides you with the same functionality, but in a generic way, like Google Highway. This should simplify your own code considerably.
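As a rough illustration, a single cmp()-style probe on top of Highway could look roughly like this (static single-target dispatch only; the function name first_not_less is mine, and the exact op names should be double-checked against Highway's quick reference):
#include <cstdint>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Index of the first lane in one SIMD block that is not less than key, or -1
// if every lane is smaller; the lane count adapts to the compiled-for target.
std::intptr_t first_not_less(std::int32_t key, const std::int32_t* block) {
    const hn::ScalableTag<std::int32_t> d;
    const auto needle = hn::Set(d, key);
    const auto values = hn::LoadU(d, block);              // unaligned load of Lanes(d) keys
    return hn::FindFirstTrue(d, hn::Ge(values, needle));  // -1 when no lane matches
}
Highway also offers runtime dispatch across targets, which its documentation covers.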