
I'm trying to find the closest point (by Euclidean distance) from a user-input point to a list of 50,000 points that I have. Note that the list of points changes all the time, and the closest point depends on when and where the user clicks.

# find the nearest point from a given point to a large list of points
import numpy as np

def distance(pt_1, pt_2):
    pt_1 = np.array((pt_1[0], pt_1[1]))
    pt_2 = np.array((pt_2[0], pt_2[1]))
    return np.linalg.norm(pt_1 - pt_2)

def closest_node(node, nodes):
    pt = []
    dist = 9999999
    for n in nodes:
        if distance(node, n) <= dist:
            dist = distance(node, n)
            pt = n
    return pt

a = []
for x in range(50000):
    a.append((np.random.randint(0, 1000), np.random.randint(0, 1000)))

some_pt = (1, 2)
closest_node(some_pt, a)
asked Jul 6, 2013 at 20:19
  • I was working on a similar problem and found this. Also, scipy.spatial.KDTree is the way to go for such approaches. Commented Sep 27, 2013 at 15:56
  • The part that says "the list changes all the time" could be expanded a bit; it might hint at ways to improve the code's performance. Is the list updated randomly, or are some points added and some lost every few seconds? Commented Jul 12, 2015 at 6:14
  • The ball tree method in scikit-learn does this efficiently if the same set of points has to be searched through repeatedly. The points are sorted into a tree structure in a preprocessing step to make finding the closest point quicker. scikit-learn.org/stable/modules/generated/… Commented Jan 17, 2018 at 1:19
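A minimal sketch of the scikit-learn ball tree approach mentioned in the last comment, using the question's 50,000 random points (the BallTree class and its query call are the standard sklearn.neighbors API; the rest of the setup is illustrative):

from sklearn.neighbors import BallTree
import numpy as np

points = np.random.randint(1000, size=(50000, 2))
tree = BallTree(points)                # preprocessing step: sort the points into a tree

dist, ind = tree.query([[1, 2]], k=1)  # nearest neighbour of the point (1, 2)
print("closest point:", points[ind[0][0]], "at distance", dist[0][0])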

3 Answers


It will certainly be faster if you vectorize the distance calculations:

def closest_node(node, nodes):
    nodes = np.asarray(nodes)
    dist_2 = np.sum((nodes - node)**2, axis=1)
    return np.argmin(dist_2)

There may be some speed to gain, and a lot of clarity to lose, by using one of the dot product functions:

def closest_node(node, nodes):
    nodes = np.asarray(nodes)
    deltas = nodes - node
    dist_2 = np.einsum('ij,ij->i', deltas, deltas)
    return np.argmin(dist_2)

Ideally, you would already have your points in an array rather than a list, which will speed things up a lot.
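One thing worth noting: unlike the original, both versions above return the index of the closest point rather than the point itself, so the caller indexes back into the array. A minimal usage sketch with the question's data (the setup below is illustrative):

import numpy as np

def closest_node(node, nodes):
    # vectorized version from above: returns the index of the nearest point
    nodes = np.asarray(nodes)
    dist_2 = np.sum((nodes - node)**2, axis=1)
    return np.argmin(dist_2)

nodes = np.random.randint(1000, size=(50000, 2))
some_pt = (1, 2)
idx = closest_node(some_pt, nodes)   # index of the nearest point, not the point itself
print("closest point:", nodes[idx])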

answered Jul 6, 2013 at 23:25
  • Thanks for the response; do you mind explaining why the two methods are faster? I'm just curious as I don't come from a CS background. Commented Jul 7, 2013 at 3:28
  • Python for loops are very slow. When you run operations using numpy on all items of a vector, there are hidden loops running in C under the hood, which are much, much faster. Commented Jul 7, 2013 at 3:54

All your code could be rewritten as:

from numpy import random
from scipy.spatial import distance

def closest_node(node, nodes):
    closest_index = distance.cdist([node], nodes).argmin()
    return nodes[closest_index]

a = random.randint(1000, size=(50000, 2))
some_pt = (1, 2)
closest_node(some_pt, a)

You can just write randint(1000) instead of randint(0, 1000); the documentation of randint says:

If high is None (the default), then results are from [0, low).

You can use the size argument to randint instead of the loop and two function calls. So:

a = []
for x in range(50000):
    a.append((np.random.randint(0, 1000), np.random.randint(0, 1000)))

Becomes:

a = np.random.randint(1000, size=(50000, 2))

It's also much faster (twenty times faster in my tests).


More importantly, scipy has the scipy.spatial.distance module that contains the cdist function:

cdist(XA, XB, metric='euclidean', p=2, V=None, VI=None, w=None)

Computes distance between each pair of the two collections of inputs.

So calculating the distance in a loop is no longer needed.

You also use the for loop to find the position of the minimum, but this can be done with the argmin method of the ndarray object.

Therefore, your closest_node function can be defined simply as:

from scipy.spatial.distance import cdist

def closest_node(node, nodes):
    return nodes[cdist([node], nodes).argmin()]

I've compared the execution times of all the closest_node functions defined in this question:

Original:
1 loop, best of 3: 1.01 sec per loop
Jaime v1:
100 loops, best of 3: 3.32 msec per loop
Jaime v2:
1000 loops, best of 3: 1.62 msec per loop
Mine:
100 loops, best of 3: 2.07 msec per loop

All vectorized functions perform hundreds of times faster than the original solution.

cdist is outperformed only by the second function by Jaime, but only slightly. Certainly cdist is the simplest.
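The benchmark script itself isn't shown; the figures read like %timeit output, and a comparison along these lines can be reproduced with Python's timeit module. A sketch under that assumption (the setup and function names below are mine, not the author's):

import timeit
import numpy as np
from scipy.spatial.distance import cdist

nodes = np.random.randint(1000, size=(50000, 2))
node = (1, 2)

def closest_node_cdist(node, nodes):
    # the cdist-based version from this answer
    return nodes[cdist([node], nodes).argmin()]

def closest_node_einsum(node, nodes):
    # the einsum distance computation from Jaime's second version,
    # returning the point here so both functions are comparable
    deltas = nodes - node
    return nodes[np.einsum('ij,ij->i', deltas, deltas).argmin()]

for fn in (closest_node_cdist, closest_node_einsum):
    t = timeit.timeit(lambda: fn(node, nodes), number=100)
    print(fn.__name__, t / 100, "sec per call")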

answered Jul 14, 2016 at 22:58

By using a k-d tree, computing the distance to every point is not needed for this type of query. A k-d tree is built into scipy and can speed these programs up enormously, going from O(n) per query with a linear scan to O(log n) with the tree, if you're doing many queries. If you only make one query, the time spent constructing the tree will dominate the computation time.

from scipy.spatial import KDTree
import numpy as np

n = 10
v = np.random.rand(n, 3)        # n random points in 3-D
kdtree = KDTree(v)              # build the tree once
d, i = kdtree.query((0, 0, 0))  # distance to and index of the nearest point
print("closest point:", v[i])
answered Feb 11, 2021 at 13:48