Commit 76767d2
tianyizheng02 and github-actions authored
Consolidate the two existing kNN implementations (TheAlgorithms#8903)
* Add type hints to k_nearest_neighbours.py
* Refactor k_nearest_neighbours.py into class
* Add documentation to k_nearest_neighbours.py
* Use heap-based priority queue for k_nearest_neighbours.py
* Delete knn_sklearn.py
* updating DIRECTORY.md
* Use optional args in k_nearest_neighbours.py for demo purposes
* Fix wrong function arg in k_nearest_neighbours.py

---------

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
1 parent 5830b29 commit 76767d2
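
The "heap-based priority queue" item in the commit message corresponds to replacing `sorted(distances)[:k]` with `heapq.nsmallest(k, distances)` in the diff. A minimal sketch of why the two selections are interchangeable here, using made-up `(distance, label)` pairs that are illustrative only and do not come from the commit:

```python
from heapq import nsmallest

# Hypothetical (distance, label) pairs; values are illustrative, not from the commit
distances = [(2.5, 1), (0.3, 0), (1.7, 0), (4.0, 1), (0.9, 1), (3.2, 0)]
k = 3

# Old approach: sort every pair, then slice off the first k
by_sorting = sorted(distances)[:k]

# New approach: select the k smallest pairs without fully sorting the input
by_heap = nsmallest(k, distances)

same_result = by_sorting == by_heap
print(by_heap)
```

Both calls return the pairs in ascending order, so the voting step that follows does not need to change. `nsmallest` also accepts a generator, which is presumably why the new code can build the distances lazily.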

3 files changed, +79 -81 lines changed

‎DIRECTORY.md‎

Lines changed: 0 additions & 1 deletion
@@ -507,7 +507,6 @@
 * [Gradient Descent](machine_learning/gradient_descent.py)
 * [K Means Clust](machine_learning/k_means_clust.py)
 * [K Nearest Neighbours](machine_learning/k_nearest_neighbours.py)
-* [Knn Sklearn](machine_learning/knn_sklearn.py)
 * [Linear Discriminant Analysis](machine_learning/linear_discriminant_analysis.py)
 * [Linear Regression](machine_learning/linear_regression.py)
 * Local Weighted Learning

machine_learning/k_nearest_neighbours.py

Lines changed: 79 additions & 49 deletions
@@ -1,58 +1,88 @@
+"""
+k-Nearest Neighbours (kNN) is a simple non-parametric supervised learning
+algorithm used for classification. Given some labelled training data, a given
+point is classified using its k nearest neighbours according to some distance
+metric. The most commonly occurring label among the neighbours becomes the label
+of the given point. In effect, the label of the given point is decided by a
+majority vote.
+
+This implementation uses the commonly used Euclidean distance metric, but other
+distance metrics can also be used.
+
+Reference: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
+"""
+
 from collections import Counter
+from heapq import nsmallest

 import numpy as np
 from sklearn import datasets
 from sklearn.model_selection import train_test_split

-data = datasets.load_iris()
-
-X = np.array(data["data"])
-y = np.array(data["target"])
-classes = data["target_names"]
-
-X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-
-def euclidean_distance(a, b):
-    """
-    Gives the euclidean distance between two points
-    >>> euclidean_distance([0, 0], [3, 4])
-    5.0
-    >>> euclidean_distance([1, 2, 3], [1, 8, 11])
-    10.0
-    """
-    return np.linalg.norm(np.array(a) - np.array(b))
-
-
-def classifier(train_data, train_target, classes, point, k=5):
-    """
-    Classifies the point using the KNN algorithm
-    k closest points are found (ranked in ascending order of euclidean distance)
-    Params:
-    :train_data: Set of points that are classified into two or more classes
-    :train_target: List of classes in the order of train_data points
-    :classes: Labels of the classes
-    :point: The data point that needs to be classified
-
-    >>> X_train = [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
-    >>> y_train = [0, 0, 0, 0, 1, 1, 1]
-    >>> classes = ['A','B']; point = [1.2,1.2]
-    >>> classifier(X_train, y_train, classes,point)
-    'A'
-    """
-    data = zip(train_data, train_target)
-    # List of distances of all points from the point to be classified
-    distances = []
-    for data_point in data:
-        distance = euclidean_distance(data_point[0], point)
-        distances.append((distance, data_point[1]))
-    # Choosing 'k' points with the least distances.
-    votes = [i[1] for i in sorted(distances)[:k]]
-    # Most commonly occurring class among them
-    # is the class into which the point is classified
-    result = Counter(votes).most_common(1)[0][0]
-    return classes[result]
+
+class KNN:
+    def __init__(
+        self,
+        train_data: np.ndarray[float],
+        train_target: np.ndarray[int],
+        class_labels: list[str],
+    ) -> None:
+        """
+        Create a kNN classifier using the given training data and class labels
+        """
+        self.data = zip(train_data, train_target)
+        self.labels = class_labels
+
+    @staticmethod
+    def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
+        """
+        Calculate the Euclidean distance between two points
+        >>> KNN._euclidean_distance(np.array([0, 0]), np.array([3, 4]))
+        5.0
+        >>> KNN._euclidean_distance(np.array([1, 2, 3]), np.array([1, 8, 11]))
+        10.0
+        """
+        return np.linalg.norm(a - b)
+
+    def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
+        """
+        Classify a given point using the kNN algorithm
+        >>> train_X = np.array(
+        ...     [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
+        ... )
+        >>> train_y = np.array([0, 0, 0, 0, 1, 1, 1])
+        >>> classes = ['A', 'B']
+        >>> knn = KNN(train_X, train_y, classes)
+        >>> point = np.array([1.2, 1.2])
+        >>> knn.classify(point)
+        'A'
+        """
+        # Distances of all points from the point to be classified
+        distances = (
+            (self._euclidean_distance(data_point[0], pred_point), data_point[1])
+            for data_point in self.data
+        )
+
+        # Choosing k points with the shortest distances
+        votes = (i[1] for i in nsmallest(k, distances))
+
+        # Most commonly occurring class is the one into which the point is classified
+        result = Counter(votes).most_common(1)[0][0]
+        return self.labels[result]


 if __name__ == "__main__":
-    print(classifier(X_train, y_train, classes, [4.4, 3.1, 1.3, 1.4]))
+    import doctest
+
+    doctest.testmod()
+
+    iris = datasets.load_iris()
+
+    X = np.array(iris["data"])
+    y = np.array(iris["target"])
+    iris_classes = iris["target_names"]
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    iris_point = np.array([4.4, 3.1, 1.3, 1.4])
+    classifier = KNN(X_train, y_train, iris_classes)
+    print(classifier.classify(iris_point, k=3))
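
One caveat worth checking when reading the new class: `zip` returns a single-use iterator in Python 3, so storing `zip(train_data, train_target)` in `__init__` would appear to leave nothing for a second `classify` call on the same instance to iterate over. The sketch below is not code from the commit. It assumes a hypothetical `KNNSketch` class, uses plain lists and `math.dist` in place of NumPy, and materializes the pairs with `list(...)` so that repeated calls keep working:

```python
from collections import Counter
from heapq import nsmallest
from math import dist  # available in Python 3.8+


class KNNSketch:
    """Hypothetical stand-in for the committed KNN class (not the commit's code)."""

    def __init__(self, train_data, train_target, class_labels):
        # list(...) materializes the pairs; a bare zip object would be
        # exhausted after the first pass over it
        self.data = list(zip(train_data, train_target))
        self.labels = class_labels

    def classify(self, pred_point, k=5):
        # (distance, label) for every training point, computed lazily
        distances = ((dist(p, pred_point), label) for p, label in self.data)
        # Labels of the k nearest training points
        votes = (label for _, label in nsmallest(k, distances))
        # Majority vote decides the predicted class
        return self.labels[Counter(votes).most_common(1)[0][0]]


# Same toy data as the doctest in the diff
train_X = [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
train_y = [0, 0, 0, 0, 1, 1, 1]
knn = KNNSketch(train_X, train_y, ["A", "B"])

first = knn.classify([1.2, 1.2])
second = knn.classify([3.0, 3.0])  # a second call still sees the training data
print(first, second)
```

The doctest and the `__main__` demo in the diff each call `classify` only once per instance, so neither would exercise this case.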

‎machine_learning/knn_sklearn.py‎

Lines changed: 0 additions & 31 deletions
This file was deleted.
