Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4379c64

Browse files
author
Algorithmica
authored
Add files via upload
1 parent 56c2a67 commit 4379c64

File tree

8 files changed

+1626
-0
lines changed

8 files changed

+1626
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import sys
2+
import os
3+
path = os.path.abspath(os.path.join('.'))
4+
sys.path.append(path)
5+
6+
import os
7+
import pandas as pd
8+
import seaborn as sns
9+
import numpy as np
10+
import math
11+
from itertools import product, cycle
12+
from sklearn import covariance, preprocessing, tree, svm, neighbors, metrics, linear_model, manifold, linear_model
13+
from sklearn_pandas import DataFrameMapper,CategoricalImputer
14+
from sklearn import model_selection, ensemble, preprocessing, decomposition, feature_selection
15+
import matplotlib.pyplot as plt
16+
from matplotlib.colors import ListedColormap
17+
from mpl_toolkits.mplot3d import Axes3D
18+
from sklearn.datasets import make_circles, make_moons, make_classification,make_blobs
19+
import matplotlib.cm as cm
20+
from classification_utils import *
21+
22+
23+
def grid_search_plot_one_parameter_curves_clustering(estimator, grid, X, scoring):
    """Fit ``estimator`` for every combination in ``grid`` and plot the scores.

    Args:
        estimator: Clustering estimator exposing ``set_params``. When its
            string form starts with ``'Gaussian'`` the labels come from
            ``fit_predict(X)``; otherwise ``fit(X)`` is called and the labels
            are read from ``labels_``.
        grid: Mapping of parameter name to an iterable of candidate values.
        X: Data handed to the estimator and to the scoring function.
        scoring: ``'s_score'`` (silhouette score, euclidean metric) or
            ``'ch_score'`` (Calinski-Harabasz score).

    Returns:
        None. The curve is drawn on a newly created matplotlib figure. The
        estimator is mutated and left configured with the last combination.
    """
    # Reject an unknown metric up front instead of after fitting a model.
    if scoring not in ('s_score', 'ch_score'):
        print(str(scoring) + ' metric not supported')
        return

    if scoring == 'ch_score':
        # The misspelled ``calinski_harabaz_score`` only exists in old
        # scikit-learn releases; prefer the corrected name when available.
        ch_score = getattr(metrics, 'calinski_harabasz_score', None)
        if ch_score is None:
            ch_score = metrics.calinski_harabaz_score

    name = str(estimator)
    keys, values = zip(*sorted(grid.items()))
    params = [dict(zip(keys, combo)) for combo in product(*values)]

    scores = []
    for param in params:
        estimator.set_params(**param)
        if name.startswith('Gaussian'):
            labels = estimator.fit_predict(X)
        else:
            estimator.fit(X)
            labels = estimator.labels_

        if scoring == 's_score':
            score = metrics.silhouette_score(X, labels, metric='euclidean')
        else:
            score = ch_score(X, labels)
        scores.append(score)

    # A list of dicts cannot be used as x data. Plot against the parameter
    # values themselves when a single numeric parameter is searched, and
    # against string labels otherwise.
    if len(keys) == 1:
        x_values = [param[keys[0]] for param in params]
        if not all(isinstance(v, (int, float, np.number)) for v in x_values):
            x_values = [str(v) for v in x_values]
    else:
        x_values = [str(param) for param in params]

    plt.figure()
    plt.plot(x_values, scores, marker="D")
    plt.xlabel('nclusters')
    plt.ylabel(str(scoring))
52+
53+
def grid_search_best_model_clustering(estimator, grid, X, scoring):
    """Return ``estimator`` refitted with the best-scoring combination in ``grid``.

    Args:
        estimator: Clustering estimator exposing ``set_params``. When its
            string form starts with ``'Gaussian'`` the labels come from
            ``fit_predict(X)``; otherwise ``fit(X)`` is called and the labels
            are read from ``labels_``.
        grid: Mapping of parameter name to an iterable of candidate values.
        X: Data handed to the estimator and to the scoring function.
        scoring: ``'s_score'`` (silhouette score, euclidean metric) or
            ``'ch_score'`` (Calinski-Harabasz score).

    Returns:
        The same estimator object, configured with the best parameters and
        fitted on ``X``; ``None`` when the metric is not supported or the grid
        yields no combination. The best score is printed.
    """
    # Reject an unknown metric up front instead of after fitting a model.
    # ``str()`` keeps the message from raising on a non-string metric.
    if scoring not in ('s_score', 'ch_score'):
        print(str(scoring) + ' metric not supported')
        return None

    if scoring == 'ch_score':
        # The misspelled ``calinski_harabaz_score`` only exists in old
        # scikit-learn releases; prefer the corrected name when available.
        ch_score = getattr(metrics, 'calinski_harabasz_score', None)
        if ch_score is None:
            ch_score = metrics.calinski_harabaz_score

    name = str(estimator)
    keys, values = zip(*sorted(grid.items()))
    params = [dict(zip(keys, combo)) for combo in product(*values)]

    best_param = None
    # Start below any reachable score: silhouette values may be negative, and
    # a 0.0 baseline would discard every such candidate.
    best_score = float('-inf')
    for param in params:
        estimator.set_params(**param)
        if name.startswith('Gaussian'):
            labels = estimator.fit_predict(X)
        else:
            estimator.fit(X)
            labels = estimator.labels_

        if scoring == 's_score':
            score = metrics.silhouette_score(X, labels, metric='euclidean')
        else:
            score = ch_score(X, labels)

        if score > best_score:
            best_score = score
            best_param = param

    if best_param is None:
        return None

    # Refit so the returned estimator reflects the winning parameters rather
    # than the last combination tried.
    estimator.set_params(**best_param)
    estimator.fit(X)
    print("Best score:" + str(best_score))
    return estimator
87+
88+
def grid_search_plot_models_2d_clustering(estimator, grid, X, xlim=None, ylim=None):
    """Fit ``estimator`` for every combination in ``grid`` and plot each in 2D.

    One subplot is drawn per parameter combination, titled with the string
    form of that combination.

    Args:
        estimator: Clustering estimator exposing ``set_params`` and ``fit``.
        grid: Mapping of parameter name to an iterable of candidate values.
        X: Data to cluster and plot.
        xlim: Forwarded to ``plot_model_2d_clustering``.
        ylim: Forwarded to ``plot_model_2d_clustering``.

    Returns:
        None. The estimator is mutated and left fitted with the last
        combination.
    """
    keys, values = zip(*sorted(grid.items()))
    params = [dict(zip(keys, combo)) for combo in product(*values)]
    n = len(params)
    if n == 0:
        # An empty candidate list yields no combinations: nothing to draw.
        return

    # Recent matplotlib releases ship the seaborn style sheets under a
    # ``seaborn-v0_8`` prefix; fall back to that when the old name is gone.
    style = 'seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8'
    plt.style.use(style)

    # Near-square grid with at least ``n`` cells. The former
    # floor(sqrt(n)) x ceil(sqrt(n)) layout was too small for n = 3, 7, 8, ...
    # and the surplus models were silently skipped.
    n_cols = int(math.ceil(math.sqrt(n)))
    n_rows = (n + n_cols - 1) // n_cols
    _, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), dpi=80, squeeze=False)
    flat_axes = axes.reshape(-1)

    for ax, param in zip(flat_axes, params):
        estimator.set_params(**param)
        estimator.fit(X)
        plot_model_2d_clustering(estimator, X, ax, xlim, ylim, str(param), False)

    # Hide the cells of the grid that have no model to show.
    for unused_ax in flat_axes[n:]:
        unused_ax.set_visible(False)
    plt.tight_layout()
103+
104+
def grid_search_plot_models_3d_clustering(estimator, grid, X, xlim=None, ylim=None, zlim=None):
    """Fit ``estimator`` for every combination in ``grid`` and plot each in 3D.

    One 3D subplot is drawn per parameter combination, titled with the string
    form of that combination.

    Args:
        estimator: Clustering estimator exposing ``set_params`` and ``fit``.
        grid: Mapping of parameter name to an iterable of candidate values.
        X: Data to cluster and plot.
        xlim: Forwarded to ``plot_model_3d_clustering``.
        ylim: Forwarded to ``plot_model_3d_clustering``.
        zlim: Forwarded to ``plot_model_3d_clustering``.

    Returns:
        None. The estimator is mutated and left fitted with the last
        combination.
    """
    keys, values = zip(*sorted(grid.items()))
    params = [dict(zip(keys, combo)) for combo in product(*values)]
    n = len(params)
    if n == 0:
        # An empty candidate list yields no combinations: nothing to draw.
        return

    # Recent matplotlib releases ship the seaborn style sheets under a
    # ``seaborn-v0_8`` prefix; fall back to that when the old name is gone.
    style = 'seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8'
    plt.style.use(style)

    # Near-square grid with at least ``n`` cells. The former
    # floor(sqrt(n)) x ceil(sqrt(n)) layout was too small for n = 3, 7, 8, ...
    # and the surplus models were silently skipped.
    n_cols = int(math.ceil(math.sqrt(n)))
    n_rows = (n + n_cols - 1) // n_cols
    _, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), dpi=80,
                           squeeze=False, subplot_kw=dict(projection='3d'))
    flat_axes = axes.reshape(-1)

    for ax, param in zip(flat_axes, params):
        estimator.set_params(**param)
        estimator.fit(X)
        plot_model_3d_clustering(estimator, X, ax, xlim, ylim, zlim, str(param), False)

    # Hide the cells of the grid that have no model to show.
    for unused_ax in flat_axes[n:]:
        unused_ax.set_visible(False)
    plt.tight_layout()
119+
120+
def plot_model_3d_clustering(estimator, X, ax = None, xlim=None, ylim=None, zlim=None, title=None, new_window=True, rotation=False):
    """Plot ``X`` in 3D coloured by the cluster assignment of ``estimator``.

    The assignment comes from ``estimator.fit_predict(X)`` when the string
    form of the estimator starts with ``'Gaussian'`` (note that this call fits
    the estimator again); otherwise it is read from ``estimator.labels_``, so
    the estimator must already expose that attribute.

    If the estimator has a ``cluster_centers_`` attribute, those centers are
    drawn on the same axes with ``s=200``.

    Args:
        estimator: Clustering estimator to visualise.
        X: Data to plot.
        ax, xlim, ylim, zlim, title, new_window, rotation: Forwarded to
            ``plot_data_3d_classification``.

    Returns:
        None.
    """
    uses_fit_predict = str(estimator).startswith('Gaussian')
    cluster_ids = estimator.fit_predict(X) if uses_fit_predict else estimator.labels_

    # The data plot hands back the axes that the centers are drawn on.
    ax = plot_data_3d_classification(X, cluster_ids, ax, xlim, ylim, zlim, title, new_window, rotation)

    if not hasattr(estimator, 'cluster_centers_'):
        return
    plot_data_3d(estimator.cluster_centers_, ax, new_window=False, title=title, s=200)
131+
132+
def plot_model_2d_clustering(estimator, X, ax = None, xlim=None, ylim=None, title=None, new_window=True):
    """Plot ``X`` in 2D coloured by the cluster assignment of ``estimator``.

    The assignment comes from ``estimator.fit_predict(X)`` when the string
    form of the estimator starts with ``'Gaussian'`` (note that this call fits
    the estimator again); otherwise it is read from ``estimator.labels_``, so
    the estimator must already expose that attribute.

    If the estimator has a ``cluster_centers_`` attribute, those centers are
    drawn on the same axes with ``s=200``.

    Args:
        estimator: Clustering estimator to visualise.
        X: Data to plot.
        ax, xlim, ylim, title, new_window: Forwarded to
            ``plot_data_2d_classification``.

    Returns:
        None.
    """
    uses_fit_predict = str(estimator).startswith('Gaussian')
    cluster_ids = estimator.fit_predict(X) if uses_fit_predict else estimator.labels_

    # The data plot hands back the axes that the centers are drawn on.
    ax = plot_data_2d_classification(X, cluster_ids, ax, xlim, ylim, title, new_window)

    if not hasattr(estimator, 'cluster_centers_'):
        return
    plot_data_2d(estimator.cluster_centers_, ax, new_window=False, title = title, s=200)
143+
144+
145+
def generate_synthetic_data_2d_clusters(n_samples, n_centers, cluster_std) :
    """Generate blob-shaped synthetic clusters with a fixed seed.

    Thin wrapper around ``make_blobs`` with ``random_state=0``, so repeated
    calls with the same arguments use the same seed.

    Args:
        n_samples: Passed to ``make_blobs`` as ``n_samples``.
        n_centers: Passed to ``make_blobs`` as ``centers``.
        cluster_std: Passed to ``make_blobs`` as ``cluster_std``.

    Returns:
        Whatever ``make_blobs`` returns for these arguments.
    """
    blob_options = dict(
        n_samples=n_samples,
        centers=n_centers,
        cluster_std=cluster_std,
        random_state=0,
    )
    return make_blobs(**blob_options)
148+
149+
def generate_synthetic_data_3d_clusters(n_samples, n_centers, cluster_std) :
    """Generate three-feature synthetic clusters with a fixed seed.

    Thin wrapper around ``make_classification`` using three informative
    features, no redundant features, one cluster per class and
    ``random_state=100``.

    Args:
        n_samples: Passed to ``make_classification`` as ``n_samples``.
        n_centers: Passed to ``make_classification`` as ``n_classes``.
        cluster_std: Accepted for signature parity with the 2D generator but
            not used by this function.

    Returns:
        Whatever ``make_classification`` returns for these arguments.
    """
    classification_options = dict(
        n_samples=n_samples,
        n_features=3,
        n_informative=3,
        n_clusters_per_class=1,
        n_redundant=0,
        n_classes=n_centers,
        random_state=100,
    )
    return make_classification(**classification_options)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /