Neighborhood Components Analysis Illustration
This example illustrates a distance metric learned to maximize nearest neighbors classification accuracy. It plots the data in the space induced by this metric next to the original point space. Please refer to the User Guide for more information.
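For reference, NCA (as described in the scikit-learn User Guide) learns a linear map $L$, stored as `components_` in scikit-learn, and scores it with a soft, leave-one-out version of nearest neighbors classification. Restated in our own notation: each point $x_i$ picks a neighbor $x_j$ with probability $p_{ij}$, and the objective is the expected number of points whose chosen neighbor has the same label, where $C_i$ denotes the indices of the points that share $x_i$'s class.

$$
p_{ij} = \frac{\exp\left(-\lVert L x_i - L x_j \rVert^2\right)}{\sum_{k \neq i} \exp\left(-\lVert L x_i - L x_k \rVert^2\right)}, \quad p_{ii} = 0,
\qquad
f(L) = \sum_{i} \sum_{j \in C_i} p_{ij}.
$$

In the plots of this example, the width of the link from point no. 3 to point $j$ is $5\,p_{3j}$, with $L$ equal to the identity in the original space and to the learned map in the embedding.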
```python
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from scipy.special import logsumexp

from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis
```
Original points
First, we create a data set of 9 samples from 3 classes and plot the points in the original space. For this example, we focus on the classification of point no. 3. The thickness of the link between point no. 3 and another point j is proportional to exp(-||x_3 - x_j||^2), normalized over all j != 3 (a softmax over negative squared distances), so closer points get thicker links.
```python
X, y = make_classification(
    n_samples=9,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=3,
    n_clusters_per_class=1,
    class_sep=1.0,
    random_state=0,
)

plt.figure(1)
ax = plt.gca()
for i in range(X.shape[0]):
    ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis("equal")  # so that boundaries are displayed correctly as circles


def link_thickness_i(X, i):
    # Squared Euclidean distances from point i to every point.
    diff_embedded = X[i] - X
    dist_embedded = np.einsum("ij,ij->i", diff_embedded, diff_embedded)
    dist_embedded[i] = np.inf  # exp(-inf) = 0: no link from a point to itself

    # compute exponentiated distances (use the log-sum-exp trick to
    # avoid numerical instabilities)
    exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))
    return exp_dist_embedded


def relate_point(X, i, ax):
    pt_i = X[i]
    # The weights only depend on i, so compute them once outside the loop.
    thickness = link_thickness_i(X, i)
    for j, pt_j in enumerate(X):
        if i != j:
            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
            ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])


i = 3
relate_point(X, i, ax)
plt.show()
```
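As a sanity check on the link weights, the sketch below re-implements the computation done in `link_thickness_i` as a standalone function and compares it with a naive softmax. The names `link_weights` and `naive_link_weights` are hypothetical, and the toy inputs are assumptions: random points at unit scale for the ordinary case, and collinear points 100 units apart, where every exp(-d^2) underflows to zero.

```python
import numpy as np
from scipy.special import logsumexp


def link_weights(X, i):
    """Softmax of negative squared distances from point i, via log-sum-exp."""
    diff = X[i] - X
    sq_dist = np.einsum("ij,ij->i", diff, diff)  # squared Euclidean distances
    sq_dist[i] = np.inf  # exp(-inf) = 0, so point i never links to itself
    return np.exp(-sq_dist - logsumexp(-sq_dist))


def naive_link_weights(X, i):
    """Same softmax computed directly, without the log-sum-exp trick."""
    diff = X[i] - X
    sq_dist = np.einsum("ij,ij->i", diff, diff)
    sq_dist[i] = np.inf
    e = np.exp(-sq_dist)
    return e / e.sum()


# Ordinary case: a few random points at unit scale.
rng = np.random.RandomState(0)
X_small = rng.randn(9, 2)
w = link_weights(X_small, 3)
print(w.sum(), w[3])  # sums to 1 (up to rounding); the self-link weight is 0
print(np.allclose(w, naive_link_weights(X_small, 3)))  # True: both agree here

# Hard case: collinear points 100 units apart, so every exp(-d^2) underflows.
X_far = np.column_stack([100.0 * np.arange(9), np.zeros(9)])
with np.errstate(divide="ignore", invalid="ignore"):
    print(naive_link_weights(X_far, 3))  # all nan: 0 / 0
w_far = link_weights(X_far, 3)
print(w_far[[2, 4]])  # about 0.5 each: the two nearest neighbors share the mass
```

Subtracting the log-sum-exp shifts every exponent so that the largest one is near zero, which is why the second version stays finite where the direct ratio collapses to 0 / 0.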
Learning an embedding
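We use NeighborhoodComponentsAnalysis to learn an embedding and plot the points after the transformation. We then redraw the links from point no. 3, with their thicknesses now computed from distances in the embedded space, so the thickest links point to its nearest neighbors under the learned metric.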
```python
nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
nca = nca.fit(X, y)

plt.figure(2)
ax2 = plt.gca()
X_embedded = nca.transform(X)
# i is still 3 here; the loop below reuses the name for the point labels.
relate_point(X_embedded, i, ax2)

for i in range(len(X)):
    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va="center", ha="center")
    ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax2.set_title("NCA embedding")
ax2.axes.get_xaxis().set_visible(False)
ax2.axes.get_yaxis().set_visible(False)
ax2.axis("equal")
plt.show()
```
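To check numerically that fitting improves the quantity NCA optimizes, the sketch below refits on the same data with `init="identity"` (an assumption that departs from the example's default initialization, chosen so the starting point is exactly the original space) and evaluates `nca_objective`, a hypothetical helper of our own that computes the expected number of correctly classified points under stochastic nearest neighbor assignment. It also checks that `transform` applies the linear map stored in `components_`.

```python
import numpy as np
from scipy.special import logsumexp
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis


def nca_objective(X_embedded, y):
    """Expected number of points whose stochastic neighbor shares their class."""
    diff = X_embedded[:, None, :] - X_embedded[None, :, :]
    sq_dist = np.einsum("ijk,ijk->ij", diff, diff)  # pairwise squared distances
    np.fill_diagonal(sq_dist, np.inf)  # exclude self-links (p_ii = 0)
    # Row-wise softmax of negative squared distances, via log-sum-exp.
    p = np.exp(-sq_dist - logsumexp(-sq_dist, axis=1, keepdims=True))
    same_class = y[:, None] == y[None, :]
    return np.sum(p * same_class)


X, y = make_classification(
    n_samples=9, n_features=2, n_informative=2, n_redundant=0,
    n_classes=3, n_clusters_per_class=1, class_sep=1.0, random_state=0,
)

# Start from the identity so the "before" score is the original space.
nca = NeighborhoodComponentsAnalysis(init="identity", max_iter=30, random_state=0)
nca.fit(X, y)

# transform is the linear map x -> L x, with L stored in components_.
X_embedded = nca.transform(X)
print(np.allclose(X_embedded, X @ nca.components_.T))

score_before = nca_objective(X, y)
score_after = nca_objective(X_embedded, y)
print(score_before, score_after)  # both lie between 0 and len(X)
```

Since the optimizer starts at the identity and only accepts steps that lower its loss, the second score is expected to be at least the first. With the default `init="auto"` the optimization starts from a different matrix, so that comparison against the original space need not hold.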