|
7 | 7 | from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier |
8 | 8 | from sklearn.neighbors import KNeighborsClassifier |
9 | 9 | from sklearn.neural_network import MLPClassifier |
| 10 | +from sklearn.naive_bayes import GaussianNB |
| 11 | +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
| 12 | +from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis |
10 | 13 | import numpy as np |
11 | 14 | import matplotlib.pyplot as plt |
12 | 15 | import pickle |
@@ -231,3 +234,30 @@ def run_classifier(clfr, x_train_data, y_train_data, x_test_data, y_test_data, a |
231 | 234 | clf = MLPClassifier() |
232 | 235 | run_classifier(clf, X_train, y_train, X_test, y_test, "CNN-MLP Accuracy: {0:0.1f}%", |
233 | 236 | "Multi-layer Perceptron Confusion matrix") |
| 237 | + |
| 238 | +# GaussianNB defaults: |
| 239 | +# priors=None |
| 240 | + |
| 241 | +# classify the images with a Gaussian Naive Bayes Classifier |
| 242 | +print('Gaussian Naive Bayes Classifier starting ...') |
| 243 | +clf = GaussianNB() |
| 244 | +run_classifier(clf, X_train, y_train, X_test, y_test, "CNN-GNB Accuracy: {0:0.1f}%", |
| 245 | + "Gaussian Naive Bayes Confusion matrix") |
| 246 | + |
| 247 | +# LinearDiscriminantAnalysis defaults: |
| 248 | +# solver=’svd’, shrinkage=None, priors=None, n_components=None, store_covariance=False, tol=0.0001 |
| 249 | + |
| 250 | +# classify the images with a Quadratic Discriminant Analysis Classifier |
| 251 | +print('Linear Discriminant Analysis Classifier starting ...') |
| 252 | +clf = LinearDiscriminantAnalysis() |
| 253 | +run_classifier(clf, X_train, y_train, X_test, y_test, "CNN-LDA Accuracy: {0:0.1f}%", |
| 254 | + "Linear Discriminant Analysis Confusion matrix") |
| 255 | + |
| 256 | +# QuadraticDiscriminantAnalysis defaults: |
| 257 | +# priors=None, reg_param=0.0, store_covariance=False, tol=0.0001, store_covariances=None |
| 258 | + |
| 259 | +# classify the images with a Quadratic Discriminant Analysis Classifier |
| 260 | +print('Quadratic Discriminant Analysis Classifier starting ...') |
| 261 | +clf = QuadraticDiscriminantAnalysis() |
| 262 | +run_classifier(clf, X_train, y_train, X_test, y_test, "CNN-QDA Accuracy: {0:0.1f}%", |
| 263 | + "Quadratic Discriminant Analysis Confusion matrix") |
0 commit comments