I'm trying to find the closest point (by Euclidean distance) from a user-input point to a list of 50,000 points that I have. Note that the list of points changes all the time, and the closest point depends on when and where the user clicks.
# Find the nearest point from a given point in a large list of points
import numpy as np

def distance(pt_1, pt_2):
    pt_1 = np.array((pt_1[0], pt_1[1]))
    pt_2 = np.array((pt_2[0], pt_2[1]))
    return np.linalg.norm(pt_1 - pt_2)

def closest_node(node, nodes):
    pt = []
    dist = 9999999
    for n in nodes:
        if distance(node, n) <= dist:
            dist = distance(node, n)
            pt = n
    return pt

a = []
for x in range(50000):
    a.append((np.random.randint(0, 1000), np.random.randint(0, 1000)))

some_pt = (1, 2)
closest_node(some_pt, a)
3 Answers
It will certainly be faster if you vectorize the distance calculations:
def closest_node(node, nodes):
    nodes = np.asarray(nodes)
    dist_2 = np.sum((nodes - node)**2, axis=1)
    return np.argmin(dist_2)
There may be some speed to gain, and a lot of clarity to lose, by using one of the dot product functions:
def closest_node(node, nodes):
    nodes = np.asarray(nodes)
    deltas = nodes - node
    dist_2 = np.einsum('ij,ij->i', deltas, deltas)
    return np.argmin(dist_2)
Ideally, you would already have your list of points in an array rather than a list, which will speed things up a lot.
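Note that, unlike the original, both functions above return the index of the closest node rather than the node itself. A minimal sketch of recovering the point from that index (using the first version and made-up sample points):

```python
import numpy as np

def closest_node(node, nodes):
    # vectorized squared Euclidean distances to every node
    nodes = np.asarray(nodes)
    dist_2 = np.sum((nodes - node)**2, axis=1)
    return np.argmin(dist_2)

points = [(0, 0), (10, 10), (1, 3)]
idx = closest_node((2, 2), points)
print(points[idx])  # (1, 3) -- squared distances are 8, 128, 2
```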
Comments:

- Thanks for the response, do you mind explaining why the two methods are faster? I'm just curious as I don't come from a CS background. – dassouki, Jul 7, 2013
- Python for loops are very slow. When you run operations using numpy on all items of a vector, there are hidden loops running in C under the hood, which are much, much faster. – Jaime, Jul 7, 2013
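To illustrate that comment, a rough timing sketch (exact speedups vary by machine and array size) comparing the same search written as a pure-Python loop and as a vectorized numpy expression:

```python
import numpy as np
import timeit

pts = np.random.rand(50000, 2)
target = np.array([0.5, 0.5])

def loop_version():
    # explicit Python loop: one interpreter iteration per point
    best, best_d = None, float('inf')
    for p in pts:
        d = (p[0] - target[0])**2 + (p[1] - target[1])**2
        if d < best_d:
            best_d, best = d, p
    return best

def vector_version():
    # the same computation done by numpy's C loops
    return pts[np.argmin(np.sum((pts - target)**2, axis=1))]

t_loop = timeit.timeit(loop_version, number=1)
t_vec = timeit.timeit(vector_version, number=1)
print(t_loop / t_vec)  # typically a large ratio
```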
All your code could be rewritten as:
from numpy import random
from scipy.spatial import distance

def closest_node(node, nodes):
    closest_index = distance.cdist([node], nodes).argmin()
    return nodes[closest_index]

a = random.randint(1000, size=(50000, 2))
some_pt = (1, 2)
closest_node(some_pt, a)
You can just write randint(1000) instead of randint(0, 1000); the documentation of randint says:

    If high is None (the default), then results are from [0, low).
You can use the size argument to randint instead of the loop and two function calls. So:
a = []
for x in range(50000):
    a.append((np.random.randint(0, 1000), np.random.randint(0, 1000)))
Becomes:
a = np.random.randint(1000, size=(50000, 2))
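A quick check of both points above, a sketch whose bounds follow from the documented [0, low) behaviour:

```python
import numpy as np

# One call generates the whole (50000, 2) array; with high omitted,
# randint(1000) draws from [0, 1000).
a = np.random.randint(1000, size=(50000, 2))
print(a.shape)                           # (50000, 2)
print(a.min() >= 0 and a.max() < 1000)   # True
```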
It's also much faster (twenty times faster in my tests).
More importantly, scipy has the scipy.spatial.distance module, which contains the cdist function:

    cdist(XA, XB, metric='euclidean', p=2, V=None, VI=None, w=None)
    Computes distance between each pair of the two collections of inputs.

So calculating the distance in a loop is no longer needed.
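As a small sketch of what cdist returns, with sample points chosen so the distances are easy to verify by hand:

```python
import numpy as np
from scipy.spatial.distance import cdist

nodes = np.array([[0, 0], [3, 4], [6, 8]])
node = (0, 0)

# cdist expects two 2-D arrays; wrapping node in a list makes it shape (1, 2)
d = cdist([node], nodes)
print(d.shape)  # (1, 3): one row of distances from node to every point
print(d)        # [[ 0.  5. 10.]]
```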
You also use the for loop to find the position of the minimum, but this can be done with the argmin method of the ndarray object. Therefore, your closest_node function can be defined simply as:
from scipy.spatial.distance import cdist

def closest_node(node, nodes):
    return nodes[cdist([node], nodes).argmin()]
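A usage sketch of this final version, checked against a brute-force squared-distance computation:

```python
import numpy as np
from scipy.spatial.distance import cdist

def closest_node(node, nodes):
    # cdist returns a (1, n) row of distances; argmin gives the index
    return nodes[cdist([node], nodes).argmin()]

a = np.random.randint(1000, size=(50000, 2))
pt = closest_node((1, 2), a)

# sanity check: pt sits at the minimum squared distance from (1, 2)
sq = np.sum((a - np.array([1, 2]))**2, axis=1)
print(np.sum((pt - np.array([1, 2]))**2) == sq.min())  # True
```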
I've compared the execution times of all the closest_node functions defined in this question:

Original:  1 loop, best of 3: 1.01 sec per loop
Jaime v1:  100 loops, best of 3: 3.32 msec per loop
Jaime v2:  1000 loops, best of 3: 1.62 msec per loop
Mine:      100 loops, best of 3: 2.07 msec per loop

All vectorized functions perform hundreds of times faster than the original solution.
cdist is outperformed only by the second function by Jaime, but only slightly. Certainly cdist is the simplest.
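Timings like these are machine-dependent, but it's easy to verify that all three vectorized versions agree on the answer. A sketch reusing the function bodies from the answers above:

```python
import numpy as np
from scipy.spatial.distance import cdist

nodes = np.random.randint(1000, size=(50000, 2))
node = np.array([1, 2])

def jaime_v1(node, nodes):
    return np.argmin(np.sum((nodes - node)**2, axis=1))

def jaime_v2(node, nodes):
    deltas = nodes - node
    return np.argmin(np.einsum('ij,ij->i', deltas, deltas))

def cdist_version(node, nodes):
    return cdist([node], nodes).argmin()

i1, i2, i3 = jaime_v1(node, nodes), jaime_v2(node, nodes), cdist_version(node, nodes)
print(i1 == i2 == i3)  # True: all three find the same index
```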
By using a k-d tree, computing the distance to every point is not needed for this type of query. It's also built into scipy and can speed up these programs enormously, taking each lookup from O(n) down to roughly O(log n), which matters if you're doing many queries. If you only make one query, the time spent constructing the tree will dominate the computation time.
from scipy.spatial import KDTree
import numpy as np
n = 10
v = np.random.rand(n, 3)
kdtree = KDTree(v)
d, i = kdtree.query((0,0,0))
print("closest point:", v[i])
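When there are many queries, the tree pays off because it is built once and reused; a sketch of batch lookups (KDTree.query also accepts an array of query points):

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
points = rng.random((50000, 2))
tree = KDTree(points)            # built once

# many queries against the same tree, one cheap lookup each
queries = rng.random((1000, 2))
dists, idxs = tree.query(queries)
nearest = points[idxs]           # the nearest stored point for each query
print(nearest.shape)             # (1000, 2)
```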
scipy.spatial.KDTree is the way to go for such approaches.