I'm trying to optimise my implementation of a static kd tree to perform orthogonal range searches in C++.
My problem: The code performs slowly even for a small number of queries when the number of points is around 105.
I've constructed the tree based on this (i.e, It's a static kd tree where the data is stored both in the nodes and the leaves) and the orthogonal range searching algorithm based on Ch 11 of this.
Relevant algorithm: (My implementation returns the points instead of count)
int rangeCount(Range Q, KDNode t, Rectangle C)
(1) if (t is a leaf)
(a) if (Q contains t) return 1,
(b) else return 0.
(2) if (t is not a leaf)
(a) if (C does not intersect Q) return 0.
(b) else if (C is a subset of Q) return t:size.
(c) else, split C along t’s cutting dimension and cutting value, letting C1 and C2 be the two rectangles. Return (rangeCount(Q; t:left; C1) + rangeCount(Q; t:right; C2)).
My code (the relevant function is the 2nd recursive_query()
):
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <random>
#include <chrono>
using namespace std;
typedef long long ll;
typedef pair<int, int> T;
typedef pair<pair<int, int>, pair<int, int> > R;
bool comp_x(const T &a, const T &b)
{
return a.first < b.first;
}
bool comp_y(const T &a, const T &b)
{
return a.second < b.second;
}
template<int k=2>
class kd_tree
{
private:
vector<T> array; // array[0] is a dummy
int N;
void recursive_build(int depth, int node, int begin, int end, vector<T> &elements)
{
if(end < begin)
return;
if(begin == end){
array[node] = elements[begin];
return;
}
int axis = depth % k;
int median;
// Find median for axis (Todo: Implement quickselect)
if(axis == 0) // x-axis
{
sort(elements.begin()+begin, elements.begin()+end+1, comp_x);
median = (begin+end+1)/2;
}
else
{
sort(elements.begin()+begin, elements.begin()+end+1, comp_y);
median = (begin+end+1)/2;
}
array[node] = elements[median];
recursive_build(depth+1, 2*node, begin, median-1, elements);
recursive_build(depth+1, 2*node+1, median+1, end, elements);
}
void return_subtree(int node, vector<T> &query_list)
{
if(node > N || array[node].first == 0)
return;
query_list.push_back(array[node]);
return_subtree(2*node, query_list);
return_subtree(2*node+1, query_list);
}
void recursive_query(int depth, int node, int x1, int y1, int x2, int y2, vector<T> &query)
{
if(node > N)
return;
int axis = depth % k;
if(axis == 0) // x-axis
{
if(array[node].first < x1)
{
recursive_query(depth+1, 2*node+1, x1, y1, x2, y2, query);
}
else if(array[node].first > x2)
{
recursive_query(depth+1, 2*node, x1, y1, x2, y2, query);
}
else
{
if(array[node].second >= y1 && array[node].second <= y2)
query.push_back(array[node]);
recursive_query(depth+1, 2*node, x1, y1, x2, y2, query);
recursive_query(depth+1, 2*node+1, x1, y1, x2, y2, query);
}
}
else
{
if(array[node].second < y1)
{
recursive_query(depth+1, 2*node+1, x1, y1, x2, y2, query);
}
else if(array[node].second > y2)
{
recursive_query(depth+1, 2*node, x1, y1, x2, y2, query);
}
else
{
if(array[node].first >= x1 && array[node].first <= x2)
query.push_back(array[node]);
recursive_query(depth+1, 2*node, x1, y1, x2, y2, query);
recursive_query(depth+1, 2*node+1, x1, y1, x2, y2, query);
}
}
}
void recursive_query(int depth, int node, R cell, R query, vector<T> &query_list)
{
int left = 2*node;
int right = 2*node+1;
// node is a leaf node
if(array[left].first == 0 && array[right].first == 0)
{
// if the leaf lies inside the query rectangle then add it to query_list
if(array[node].first >= query.first.first && array[node].first <= query.second.first
&& array[node].second >= query.first.second && array[node].second <= query.second.second)
{
query_list.push_back(array[node]);
}
return;
}
// node is not a leaf
// cell doesnt intersect the query
if(cell.first.first > query.second.first || cell.second.first < query.first.first
|| cell.first.second > query.second.second || cell.second.second < query.first.second)
{
return;
}
// cell is a subset of query
if(cell.first.first >= query.first.first && cell.second.first <= query.second.first
&& cell.first.second >= query.first.second && cell.second.second <= query.second.second)
{
return_subtree(node, query_list);
return;
}
// if the node lies within bounds then add it to query_list
if(array[node].first >= query.first.first && array[node].first <= query.second.first
&& array[node].second >= query.first.second && array[node].second <= query.second.second)
{
query_list.push_back(array[node]);
}
depth = depth % k;
if(depth == 0) // splitting planes is the x-axis
{
if(array[left].first != 0){
R cell1 = make_pair(cell.first, make_pair(array[node].first, cell.second.second));
recursive_query(depth+1, left, cell1, query, query_list);
}
if(array[right].first != 0){
R cell2 = make_pair(make_pair(array[node].first, cell.first.second), cell.second);
recursive_query(depth+1, right, cell2, query, query_list);
}
}
else // splitting plane is the y-axis
{
if(array[left].first != 0){
R cell1 = make_pair(cell.first, make_pair(cell.second.first, array[node].second));
recursive_query(depth+1, left, cell1, query, query_list);
}
if(array[right].first != 0){
R cell2 = make_pair(make_pair(cell.first.first, array[node].second), cell.second);
recursive_query(depth+1, right, cell2, query, query_list);
}
}
}
public:
kd_tree(vector<T> &elements)
{
N = elements.size();
array.resize(2*k*N+1);
recursive_build(0, 1, 0, N-1, elements);
}
void orthogonal_query(int x1, int y1, int x2, int y2, vector<T> &query_list)
{
// recursive_query(0, 1, x1, y1, x2, y2, query_list);
recursive_query(0, 1, make_pair(make_pair(1, 1), make_pair(300000, 300000)), make_pair(make_pair(x1, y1), make_pair(x2, y2)), query_list);
}
};
int main(void)
{
std::mt19937_64 generator;
std::uniform_int_distribution<int> distribution(0, 300000);
int n, q, x, y, d, count;
// scanf("%d %d", &n, &q);
n = 30000;
q = 2000;
vector<pair<int, int> > points, query;
vector<pair<int, int> >::iterator itr;
// Input Phase
while (n--){
// scanf("%d %d", &x, &y);
x = distribution(generator);
y = distribution(generator);
points.push_back(make_pair(x, y));
}
kd_tree<2> tree(points);
std::chrono::time_point<std::chrono::high_resolution_clock> start, end;
start = std::chrono::high_resolution_clock::now();
// Query Loop
while(q--){
query.clear();
// scanf("%d %d %d", &x, &y, &d);
x = distribution(generator);
y = distribution(generator);
d = distribution(generator);
int d1 = x+y+d;
tree.orthogonal_query(x, y, x+d, y+d, query);
// printf("%d\n", count);
}
end = std::chrono::high_resolution_clock::now();
ll elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(end-start).count();
cout << "\nElapsed Time: " << elapsed_time << "ms\n";
return 0;
}
-
\$\begingroup\$ That's a lot of code to have us look at. Might I consider using a profiler? cs.utah.edu/dept/old/texinfo/as/gprof_toc.html \$\endgroup\$Louis– Louis2013年02月08日 04:24:45 +00:00Commented Feb 8, 2013 at 4:24
-
\$\begingroup\$ @Louis Most of the time seems to be spend in the return_subtree function. Commenting out this function seemed to improve the time taken by a factor of roughly log(n)/10, bringing the time for n = 10^5 and number of queries, q = 10^5 to around a few seconds. So I guess i'll have to modify my kdtree to preprocess this part \$\endgroup\$Ishan Bhatnagar– Ishan Bhatnagar2013年02月08日 05:00:54 +00:00Commented Feb 8, 2013 at 5:00
-
1\$\begingroup\$ @IshanBhatnagar: I really encourage you to use structures instead of "pairs", a pair of pairs of int is not too self-describing. A pair of points already makes a bit more sense. A rectangle (defined by two points) is even better. \$\endgroup\$Matthieu M.– Matthieu M.2013年02月08日 07:45:27 +00:00Commented Feb 8, 2013 at 7:45
1 Answer 1
You could make the return_subtree
function non-recursive by using a stack
object by yourself:
void return_subtree(int node, vector<T> &query_list)
{
if (node > N) { return; }
std::stack<int> st;
st.push(node);
while (not st.empty()) {
int top = st.top();
st.pop();
// if array[top].first == 0, then the recursion ends...
// ... BUT it still need be returned!
if (top > N) { continue; }
query_list.push_back(array[top]);
if(array[top].first == 0) { continue; }
st.push(2 * top + 1);
st.push(2*node+1);
}
}
This could lead to a speed up because return_subtree
cannot be tail-call optimized (or at least, only the latest call may be optimized), and therefore the compiler is probably unable to come up with this optimization by itself...
... which of course brings us to the crux of the matter: turn your compiler optimizations on (-Os or -O2 flag on the command line for gcc/clang).
Other than that, I don't see any glaring inefficiency.
-
\$\begingroup\$ I actually eliminated this entire thing by storing a vector at each node that stores the points contained in its subtree. (Memory intensive (takes around 80MB for n = 10^5), but it reduced the time taken by a factor of around 10). Currently the entire code is 50 times slower than what it need it to be. So it has to something in the algorithm. \$\endgroup\$Ishan Bhatnagar– Ishan Bhatnagar2013年02月08日 09:02:57 +00:00Commented Feb 8, 2013 at 9:02
Explore related questions
See similar questions with these tags.