I would like a general critique of this code. Using external modules is not an option; I can only use what comes with CPython.
from collections import Counter
from typing import Sequence, NamedTuple

Coordinates = Sequence[float]


class KNNPoint(NamedTuple):
    coords: Coordinates
    classif: str


def predict(target: Coordinates, points: Sequence[KNNPoint], k: int) -> str:
    '''
    Applies the K-Nearest Neighborhood algorithm in order to find the
    classification of target.

    - target: Single data to be classified.
    - points: Collection of data which classifications are known.
    - k: The number of closest neighbors to be used.
    '''
    def distance(p: KNNPoint) -> float:
        return sum((a - b) ** 2 for a, b in zip(target, p.coords))

    neighbors = sorted(points, key=distance)
    counter = Counter(x.classif for x in neighbors[:k])
    return counter.most_common(1)[0][0]
If you'd like to run it, this gist has everything ready. It uses a dataset of mobile phones. (The gist itself does not need to be reviewed.)
1 Answer
Assumption: k << N, where N = len(points)
There is no need to sort the entire list of points!
Instead, take the first k points, compute their distances, and sort them. Then, for each successive point:
- determine its distance,
- if it is smaller than the current maximum,
- discard the maximum, and insert the new point into the sorted list.
Sorting N points by distance is O(N log N); creating and maintaining a sorted collection of the k smallest elements takes only O(N log k) comparisons, which should be considerably faster.
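The steps above can also be sketched with a bounded heap instead of a sorted list. This is one possible reading of the idea, not necessarily what was intended here. `heapq` only provides a min-heap, so the sketch stores negated distances to keep the current farthest neighbour at the root. The helper name `k_nearest` and the index used as a tie-breaker are assumptions introduced for illustration.

```python
import heapq
from typing import List, Sequence, Tuple


def k_nearest(target: Sequence[float],
              points: Sequence[Sequence[float]],
              k: int) -> List[Sequence[float]]:
    """Return up to k points closest to target, nearest first (hypothetical helper)."""
    if k <= 0:
        return []

    def distance(p: Sequence[float]) -> float:
        # Squared Euclidean distance; enough for ranking neighbours.
        return sum((a - b) ** 2 for a, b in zip(target, p))

    # Max-heap of at most k items, emulated with negated distances.
    # The index i breaks ties so that points are never compared directly.
    heap: List[Tuple[float, int, Sequence[float]]] = []
    for i, p in enumerate(points):
        item = (-distance(p), i, p)
        if len(heap) < k:
            heapq.heappush(heap, item)
        elif item[0] > heap[0][0]:
            # Strictly closer than the farthest kept point: replace it.
            heapq.heappushpop(heap, item)

    # Largest negated distance first means nearest first.
    return [p for _, _, p in sorted(heap, reverse=True)]


pts = [(0, 0), (5, 5), (1, 1), (10, 10), (2, 2)]
nearest = k_nearest((0, 0), pts, 3)
```

Each push or replace costs O(log k), so the loop stays within the O(N log k) bound mentioned above.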
I'm not sure if heapq.nsmallest() is built into CPython or not ...
# requires: import heapq
k_neighbours = heapq.nsmallest(k, points, key=distance)
counter = Counter(x.classif for x in k_neighbours)
Well, I'm disappointed to see heapq.nsmallest() performed up to 40% worse than sorted on CPython, but I'm happy to see PyPy validates my assertion that you don't need to sort the entire list.
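To reproduce this kind of comparison locally, a small `timeit` harness along the following lines may help. It is only a sketch: the synthetic random data, the sizes `N`, `DIMS` and `K`, and the helper names are assumptions rather than the gist's dataset, so the timings it prints will not match the figures quoted here.

```python
import heapq
import random
import timeit

random.seed(0)
N, DIMS, K = 2_000, 5, 5  # assumed sizes, not taken from the gist
target = [random.random() for _ in range(DIMS)]
points = [[random.random() for _ in range(DIMS)] for _ in range(N)]


def distance(p):
    return sum((a - b) ** 2 for a, b in zip(target, p))


def with_sorted():
    return sorted(points, key=distance)[:K]


def with_nsmallest():
    return heapq.nsmallest(K, points, key=distance)


# Both approaches should select the same neighbours in the same order.
assert with_sorted() == with_nsmallest()

for fn in (with_sorted, with_nsmallest):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds:.3f} s for 20 runs")
```

Varying `K` relative to `N` is worthwhile, since the relative cost of the two approaches depends on that ratio.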
Continuing with that thought, bisect.insort() may be used to maintain a list of the k-nearest neighbours so far:
# requires: import bisect
neighbours = [(float('inf'), None)] * k
for pnt in points:
    dist = distance(pnt)
    if dist < neighbours[-1][0]:
        neighbours.pop()
        bisect.insort(neighbours, (dist, pnt))
counter = Counter(pnt.classif for dist, pnt in neighbours)
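One caveat with storing `(dist, pnt)` tuples: when two distances are equal, tuple comparison falls through to the points themselves, which may be slow or, for values that do not support `<`, raise `TypeError`. Below is a minimal sketch of one way around this, assuming an extra running index as a tie-breaker. The function name `k_nearest_insort` and the index are not part of the code above; starting from an empty list instead of placeholder entries is a further assumption, chosen so that `k > len(points)` does not leave `None` entries behind.

```python
import bisect
from typing import List, NamedTuple, Sequence


class KNNPoint(NamedTuple):
    coords: Sequence[float]
    classif: str


def k_nearest_insort(target: Sequence[float],
                     points: Sequence[KNNPoint],
                     k: int) -> List[KNNPoint]:
    """Keep the k nearest points in a sorted list of (dist, index, point)."""
    def distance(p: KNNPoint) -> float:
        return sum((a - b) ** 2 for a, b in zip(target, p.coords))

    neighbours: list = []
    for i, pnt in enumerate(points):
        dist = distance(pnt)
        if len(neighbours) < k:
            # Still filling up: every point is kept.
            bisect.insort(neighbours, (dist, i, pnt))
        elif neighbours and dist < neighbours[-1][0]:
            # Closer than the current farthest neighbour: replace it.
            neighbours.pop()
            bisect.insort(neighbours, (dist, i, pnt))
    return [pnt for _, _, pnt in neighbours]


pts = [
    KNNPoint((0.0, 0.0), "a"),
    KNNPoint((1.0, 0.0), "b"),
    KNNPoint((0.0, 1.0), "a"),  # same distance from the origin as the previous point
]
closest = k_nearest_insort((0.0, 0.0), pts, 2)
```

Because the index is unique, the comparison never reaches the third tuple element, so equal distances are resolved in favour of the point seen first.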
This gave me a 4% speedup over sorted()[:k] with your gist sample set.
Significant, but not impressive. Still, it was enough encouragement to press on and look for other inefficiencies.
How about the distance() code? It gets called a lot; can we speed it up? Sure!
def predict(target: Coordinates, points: Sequence[KNNPoint], k: int, *,
            sum=sum, zip=zip) -> str:
    def distance(p: KNNPoint) -> float:
        return sum((a - b) ** 2 for a, b in zip(target, p.coords))
    # ...
Instead of being looked up in the global (and then built-in) scope on every call, the sum and zip functions are saved as variables sum, zip in the local scope, along with target, for use in distance(). Total improvement: 6%.
Applying the same sum=sum, zip=zip change to the original code, without the bisect.insort() change, also speeds it up by 2%.
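The change in name resolution can be inspected without a profiler. In the sketch below (the factory names `make_plain` and `make_bound` are made up for illustration), the code object of the inner `distance` lists `sum` and `zip` among its global names in the plain version, and among its closure (free) variables once they are bound as keyword defaults of the enclosing function. How much time this saves depends on the interpreter and its version.

```python
def make_plain(target):
    def distance(coords):
        # sum and zip are resolved through globals, then builtins.
        return sum((a - b) ** 2 for a, b in zip(target, coords))
    return distance


def make_bound(target, *, sum=sum, zip=zip):
    # sum and zip are now locals of make_bound, captured by the closure.
    def distance(coords):
        return sum((a - b) ** 2 for a, b in zip(target, coords))
    return distance


plain = make_plain((0.0, 0.0))
bound = make_bound((0.0, 0.0))

print(plain.__code__.co_names)     # names looked up as globals/builtins
print(bound.__code__.co_freevars)  # names taken from the enclosing scope
```

Both functions compute the same value; only the way the names are resolved differs.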
Further, adding insort=bisect.insort to the function declaration, and using insort(neighbours, (dist, pnt)) in the function body also provides a minor improvement.
Finally, I was concerned about neighbours[-1][0]. Looking up the first element of the last tuple in the list on every iteration seemed inefficient. We could keep track of this in a local threshold variable. Final total speedup: 7.7%.
neighbours = [(float('inf'), None)] * k
threshold = neighbours[-1][0]
for pnt in points:
    dist = distance(pnt)
    if dist < threshold:
        neighbours.pop()
        insort(neighbours, (dist, pnt))
        threshold = neighbours[-1][0]
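Put together, the pieces of this answer might look like the following self-contained sketch. It follows the snippets above with the imports added; the function is renamed `predict_insort` (a made-up name) to keep it apart from the original, and the sample points and their labels are invented for the example. It assumes `k <= len(points)`, since otherwise the `None` placeholders would reach the `Counter`.

```python
import bisect
from collections import Counter
from typing import NamedTuple, Sequence

Coordinates = Sequence[float]


class KNNPoint(NamedTuple):
    coords: Coordinates
    classif: str


def predict_insort(target: Coordinates, points: Sequence[KNNPoint], k: int, *,
                   sum=sum, zip=zip, insort=bisect.insort) -> str:
    def distance(p: KNNPoint) -> float:
        return sum((a - b) ** 2 for a, b in zip(target, p.coords))

    # Placeholders with infinite distance are displaced by real points.
    neighbours = [(float('inf'), None)] * k
    threshold = neighbours[-1][0]
    for pnt in points:
        dist = distance(pnt)
        if dist < threshold:
            neighbours.pop()
            insort(neighbours, (dist, pnt))
            threshold = neighbours[-1][0]
    counter = Counter(pnt.classif for dist, pnt in neighbours)
    return counter.most_common(1)[0][0]


# Invented sample data: two clusters with made-up labels.
sample = [
    KNNPoint((1.0, 1.0), "cheap"),
    KNNPoint((1.5, 2.0), "cheap"),
    KNNPoint((8.0, 9.0), "pricey"),
    KNNPoint((9.0, 8.5), "pricey"),
]
result = predict_insort((1.2, 1.4), sample, k=3)
print(result)  # prints "cheap": two of the three nearest points carry that label
```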
YMMV
- In CPython, nsmallest performed on par with sorted for small k, and 40% worse for large k. Using PyPy, nsmallest performed 50% better than sorted for both small and large k. – Gabriel, Jul 10, 2018 at 11:48
- And yes, it's built into CPython. – Gabriel, Jul 10, 2018 at 11:55
- Thanks, I didn't know about insort. I just found out that changing (a - b) ** 2 to (a - b) * (a - b) causes a 28% performance improvement on CPython, while it makes no difference on PyPy (I guess the PyPy JIT compiles both versions to the same thing). – Gabriel, Jul 12, 2018 at 13:32
- Wow! Good find! Makes my micro-optimizations almost laughable. Nice low-hanging fruit! – AJNeufeld, Jul 12, 2018 at 13:42
sklearn.neighbors.KDTree, which is a better data structure for this than a list and also implemented in C, otherwise.