I wrote this code to tackle a machine learning classification problem, digit recognition, as in the figure below:
For more details and the complete code, see the GitHub repository.
It's important to note that these digits are in the old Indian form; today they are called Arabic numerals.
This code tests the digit 4 only:
### Naive Bayes Classifier ###
## Done by: Meqdad Darweesh ##
### Importing statements ###
import numpy as np
from scipy.stats import multivariate_normal as mvn
from sklearn.model_selection import KFold
### Training data ###
# First class
c4 = np.array([ [8,11,12,12,13,16,20,20,22,23,22,20,19,18,16,16,15,12,5,5,5,5,5,5,5,4,4,4,3,2],
[3,10,18,20,21,23,24,21,21,22,21,20,18,16,15,14,12,12,6,4,4,4,4,4,4,3,3,3,3,1],
[2,7,8,11,19,20,22,22,23,22,22,23,22,20,18,18,20,20,18,16,14,11,11,11,9,8,4,4,4,3],
[3,12,14,19,22,23,26,23,21,21,19,18,18,17,16,15,15,13,9,9,9,8,8,5,5,4,4,4,3,1],
[3,6,8,15,18,21,24,26,25,24,24,23,21,21,20,17,17,16,16,14,11,11,11,11,9,8,8,8,5,1],
[4,8,15,19,24,25,25,26,27,25,22,21,20,19,19,17,14,11,12,13,13,13,13,13,12,10,10,9,4,3],
[5,12,15,22,22,25,25,24,21,19,19,19,18,18,18,17,14,15,13,9,9,5,5,5,5,5,5,4,4,1],
[4,9,9,9,10,17,18,18,19,21,22,20,20,16,17,16,14,12,11,8,7,7,7,7,5,4,3,3,3,1],
[0,7,8,9,16,20,20,21,21,22,23,23,23,20,19,17,16,16,13,10,9,8,8,8,7,4,4,4,3,1],
[2,7,8,9,11,11,12,17,21,22,22,21,19,19,17,16,14,11,8,8,8,8,7,4,4,4,4,4,4,2],
[7,16,19,23,24,24,21,19,18,17,16,14,14,12,11,8,8,8,7,7,4,4,4,5,4,4,4,4,4,2],
[4,11,15,18,21,22,22,22,21,20,18,16,15,14,12,7,5,5,5,5,5,4,4,4,4,4,4,4,4,2],
[1,6,14,17,22,23,24,27,27,27,23,20,20,18,17,17,15,15,10,10,9,8,7,5,4,4,3,3,3,2],
[5,9,11,16,22,22,23,24,23,22,16,15,10,10,11,11,10,10,9,5,4,4,3,4,4,4,4,4,4,2],
[2,7,9,21,22,24,26,26,26,22,20,18,18,16,15,14,12,8,8,4,4,4,4,4,4,4,3,3,2,1],
[0,10,13,16,19,21,25,26,24,22,19,17,17,15,15,13,11,7,6,6,5,3,4,4,4,3,3,3,2,1],
[3,5,5,12,15,19,21,23,24,24,24,23,20,18,17,17,15,13,10,9,8,8,7,7,7,7,6,2,2,1],
[5,7,15,18,18,20,21,20,19,18,17,17,17,16,15,15,15,13,12,12,11,9,9,4,3,3,3,3,3,3],
[4,14,18,21,24,23,22,22,20,18,16,14,12,10,6,6,4,4,4,4,4,4,4,4,4,4,4,3,3,1],
[5,7,9,14,18,18,22,22,22,20,18,17,15,13,12,10,7,4,4,4,4,4,4,3,3,3,3,3,3,3],
[5,7,8,13,17,19,19,21,21,22,20,17,15,14,13,10,7,5,5,5,5,4,4,4,4,4,4,4,4,3],
[5,8,9,16,19,19,20,23,21,18,18,12,11,10,8,4,4,4,4,4,4,3,3,3,3,3,3,3,3,1],
[3,15,18,20,23,23,20,22,19,14,14,14,13,11,8,5,3,3,3,3,3,3,3,2,2,2,2,2,2,1],
[6,13,17,20,20,21,22,20,20,18,19,19,16,16,14,14,14,11,5,5,3,3,3,3,3,2,2,2,2,1],
[5,8,8,18,19,20,20,21,22,19,17,15,14,13,10,8,5,5,5,4,4,4,4,4,3,3,3,3,3,3],
[5,9,10,12,9,8,7,7,15,18,20,20,17,12,11,9,4,4,3,3,3,3,3,3,3,3,3,3,3,2],
[5,6,13,17,19,22,21,21,22,21,20,18,18,17,14,8,8,8,9,9,7,4,4,4,3,3,3,3,3,3],
[4,10,12,16,20,21,21,19,18,17,18,18,17,16,15,14,12,11,10,6,6,4,4,3,3,3,3,3,3,3],
[5,10,14,18,22,19,21,18,18,19,19,18,15,15,12,4,4,4,3,3,3,3,3,3,3,3,3,3,3,1],
[6,9,11,12,10,10,13,16,18,21,20,19,8,5,4,4,4,4,4,4,4,4,4,4,3,3,3,3,3,3],
[4,8,12,17,18,21,21,21,17,16,15,13,7,8,8,7,7,4,4,4,3,3,3,3,4,4,3,3,3,2],
[3,7,12,17,19,20,22,20,20,19,19,18,17,16,16,15,14,13,12,9,4,4,4,3,3,3,3,3,2,1],
[2,5,8,10,10,11,11,10,13,17,19,20,22,22,20,16,15,15,13,11,8,3,3,3,3,3,3,3,2,1],
[4,8,10,11,10,15,15,17,18,19,18,20,18,17,15,13,12,7,4,4,4,4,4,4,4,4,3,3,3,2],
[2,8,12,15,18,20,19,20,21,21,23,19,19,16,16,16,14,12,10,7,7,7,7,6,3,3,3,3,2,1],
[2,13,17,18,21,22,20,18,18,17,17,15,13,11,8,8,4,4,4,4,4,4,4,4,4,4,4,4,3,1],
[6,6,9,14,15,18,20,20,22,20,16,16,15,11,8,8,8,5,4,4,4,4,4,4,4,5,5,5,5,4],
[8,13,16,20,20,20,19,17,17,17,17,15,14,13,10,6,3,3,3,4,4,4,3,3,4,3,3,3,2,2],
[5,9,17,18,19,18,17,16,14,13,12,12,11,10,4,4,4,3,3,3,3,3,3,3,4,4,3,3,3,3],
[4,6,8,11,16,17,18,20,16,17,16,17,17,16,14,12,12,10,9,9,8,8,6,4,3,3,3,2,2,2] ])
# Second class
c7 = np.array([ [5,7,7,8,8,9,9,9,8,8,7,8,8,8,9,8,9,9,9,9,9,9,10,11,12,16,25,29,29,11],
[6,6,6,7,6,6,5,6,6,6,5,6,6,7,8,8,8,8,8,8,7,7,8,8,11,17,26,27,27,19],
[3,7,7,7,7,6,7,7,7,7,8,7,8,8,9,9,9,10,9,8,8,8,9,9,12,16,24,28,28,6],
[3,7,7,8,9,9,9,8,9,8,8,8,7,7,8,7,7,7,7,8,9,8,10,11,14,20,26,26,15,3],
[5,6,7,8,7,8,7,7,7,7,7,8,8,8,7,7,6,6,7,6,7,7,7,7,8,8,25,27,30,21],
[2,7,7,7,8,8,9,9,10,10,10,9,8,9,8,7,8,9,11,10,10,13,15,17,14,16,14,13,12,4],
[6,7,7,7,7,7,9,9,8,8,6,6,5,6,6,8,8,8,10,9,9,9,9,11,16,21,22,18,14,10],
[4,6,6,6,5,6,6,6,8,7,7,7,8,7,7,7,7,7,8,8,8,10,16,17,15,14,10,8,7,4],
[5,8,9,9,7,6,7,6,7,6,7,7,7,6,6,6,6,6,7,8,8,7,9,9,9,24,30,30,30,15],
[7,8,9,5,6,5,5,5,6,6,6,6,5,5,5,5,5,5,6,7,7,7,7,7,8,12,24,30,30,20],
[4,5,5,5,6,5,6,6,6,6,5,5,6,6,6,6,7,6,6,6,8,7,7,9,16,23,26,19,9,5],
[1,7,8,8,7,6,7,6,6,6,6,6,6,6,6,6,6,6,6,7,7,8,9,11,14,17,17,18,14,8],
[7,8,8,8,7,7,5,6,5,5,5,5,6,6,7,7,7,7,7,8,8,9,9,10,17,26,29,24,17,13],
[2,3,8,8,7,7,7,6,6,7,7,6,7,6,6,7,7,8,8,8,8,10,10,21,28,30,30,30,8,1],
[8,9,9,9,7,7,8,7,8,8,9,9,8,8,7,9,9,9,8,8,9,10,11,14,17,20,14,12,9,1],
[4,7,7,7,7,9,8,9,9,9,8,9,9,10,9,11,9,10,10,11,11,11,15,15,11,12,15,13,13,6],
[5,5,6,7,7,6,6,7,8,9,10,11,11,11,11,10,9,10,11,12,12,14,15,15,14,13,13,11,8,5],
[6,6,7,7,7,7,8,8,8,9,9,10,10,11,12,13,14,15,15,15,17,20,20,18,16,13,9,7,5,3],
[1,6,6,6,7,9,9,8,8,7,6,6,6,6,6,6,7,7,8,9,8,8,8,8,9,27,30,30,30,4],
[2,5,5,5,5,6,6,6,6,6,7,7,7,7,8,7,7,7,7,8,8,7,8,9,28,28,30,30,29,5],
[4,7,7,6,7,7,7,7,6,6,6,6,5,6,6,7,6,8,9,9,9,9,10,12,15,16,17,15,8,8],
[5,6,5,5,6,6,7,7,7,9,9,8,8,8,8,6,7,6,7,6,8,7,7,10,16,19,14,14,10,6],
[5,5,7,8,9,8,6,6,6,6,7,7,7,7,7,8,8,9,8,8,9,9,12,14,16,13,14,15,12,10],
[5,5,6,7,6,6,5,5,6,7,8,8,10,12,11,12,11,11,11,11,12,15,15,14,11,10,10,9,7,5],
[1,6,7,8,9,8,8,7,9,9,9,9,8,8,8,8,8,8,8,9,11,11,14,16,21,21,18,16,13,3],
[2,6,7,7,6,6,5,6,7,7,8,9,9,9,10,10,10,9,9,9,9,9,10,14,13,14,13,13,10,5],
[6,7,7,7,8,7,7,7,6,7,7,7,8,8,8,8,9,7,8,9,10,12,16,18,17,13,13,10,8,6],
[6,6,7,7,8,7,7,7,7,8,8,9,9,9,9,11,11,10,10,11,13,11,9,10,11,12,12,11,9,6],
[4,7,8,8,9,10,10,8,7,8,7,8,9,8,10,10,9,11,10,9,8,10,12,23,20,17,13,12,11,5],
[1,7,8,8,8,8,7,7,8,8,8,9,10,10,11,12,13,12,12,13,14,11,10,10,11,11,11,8,7,3],
[5,6,7,8,9,9,7,8,7,8,8,7,8,7,8,8,11,11,11,11,12,15,14,11,11,11,11,10,8,6],
[5,6,8,8,9,9,8,7,7,7,7,8,8,10,9,9,11,12,11,11,12,13,10,8,7,9,9,9,9,3],
[7,7,7,8,8,8,8,9,7,8,8,8,8,8,8,8,8,7,7,8,8,10,13,15,20,18,19,14,9,5],
[2,6,6,7,6,6,6,6,6,7,6,6,7,7,7,8,7,8,8,9,11,12,12,15,18,20,17,16,12,1],
[5,7,7,8,7,7,7,7,7,6,7,7,8,7,7,8,8,9,9,8,9,8,9,10,17,18,22,19,13,8],
[2,4,6,6,6,6,5,6,7,8,8,8,9,10,10,10,10,10,10,11,13,17,20,23,26,24,15,13,3,2],
[2,4,6,8,8,9,10,11,8,9,7,8,7,7,8,8,8,8,8,9,10,12,16,20,23,27,28,23,9,6],
[2,5,5,5,6,7,9,10,10,9,8,8,7,8,8,7,7,8,8,9,9,11,11,12,19,25,25,20,16,1],
[5,7,8,8,8,8,9,8,9,8,8,8,8,8,8,8,10,9,11,11,11,14,16,21,24,27,21,16,9,4],
[3,6,7,8,9,8,7,8,8,7,8,7,7,6,8,6,7,7,7,7,7,7,9,23,29,30,29,26,8,2] ])
# Third class
c9 = np.array([[4,7,9,11,12,13,12,10,9,10,9,9,10,10,9,9,12,15,16,18,20,21,23,27,22,18,10,9,6,3],
[7,8,10,11,10,11,13,10,11,10,10,10,10,9,9,10,13,13,13,15,16,20,20,17,16,14,13,13,7,3],
[7,9,10,10,9,7,8,7,8,10,10,11,12,12,13,14,16,15,12,9,9,9,8,9,5,5,5,4,4,3],
[5,10,11,12,12,10,10,9,9,9,8,9,9,13,13,14,15,15,15,15,16,16,15,11,10,9,7,6,5,4],
[6,8,11,11,13,9,10,9,9,8,7,7,8,8,12,16,16,14,10,10,9,8,8,7,8,7,7,7,6,5],
[9,12,14,16,11,9,8,10,9,8,8,8,9,8,8,8,12,13,14,13,13,13,11,11,9,8,8,8,7,5],
[7,9,11,12,13,9,7,9,8,8,8,8,8,7,7,17,16,15,15,13,13,11,13,12,12,13,12,10,8,7],
[6,8,9,11,10,10,9,8,9,8,8,9,9,9,10,8,8,11,16,17,17,15,14,11,8,10,9,8,7,6],
[8,10,12,12,11,10,10,10,9,10,10,8,8,9,14,14,14,14,17,19,17,11,9,8,7,6,7,6,5,3],
[8,12,14,16,16,12,11,10,9,9,9,9,9,14,17,17,16,15,15,15,10,9,6,6,6,6,7,7,6,5],
[7,10,12,13,11,11,9,9,9,8,8,9,9,9,12,11,12,10,11,12,12,11,12,9,7,6,6,6,4,3],
[8,11,14,16,13,10,9,9,9,9,9,9,9,14,16,14,12,9,8,8,9,9,9,8,8,8,8,8,7,6],
[5,7,11,12,11,10,9,9,8,9,8,8,9,10,10,13,14,11,9,10,11,9,8,7,8,7,7,6,5,4],
[3,9,10,12,13,12,11,9,10,10,10,10,10,10,11,12,14,13,13,14,15,14,12,11,11,9,7,6,5,2],
[6,9,11,14,15,14,11,11,9,8,9,10,15,16,15,11,12,11,12,11,12,9,8,6,5,5,5,5,4,4],
[3,8,10,10,12,11,11,9,9,9,9,7,7,8,8,9,10,11,14,14,15,15,14,14,12,14,10,9,8,5],
[7,10,11,14,15,11,11,10,10,10,9,10,11,13,15,13,12,12,12,13,10,10,8,8,6,7,7,6,5,3],
[5,8,9,10,13,11,9,9,8,8,8,8,7,8,9,11,13,14,16,14,13,12,17,17,17,16,12,8,7,3],
[9,9,13,15,10,9,8,7,7,7,6,6,8,14,15,17,16,12,10,10,11,11,9,7,7,7,7,7,6,5],
[5,9,12,13,14,11,10,9,8,8,9,9,11,13,16,15,14,14,15,14,14,14,8,7,7,6,5,4,3,0],
[7,10,11,14,14,11,10,11,9,8,9,9,10,11,14,15,14,13,12,13,14,16,19,18,15,14,12,9,7,3],
[6,8,8,11,11,10,9,10,8,9,9,8,9,10,11,13,13,12,12,13,12,11,11,10,10,9,9,7,6,5],
[4,8,9,12,13,13,12,12,10,11,10,10,11,11,10,11,15,16,15,17,17,18,13,12,10,9,9,8,7,4],
[7,9,9,12,12,13,10,10,10,9,9,10,10,10,10,10,14,14,12,10,10,11,13,12,13,16,15,12,10,8],
[7,10,11,12,14,11,11,10,10,10,10,10,10,14,15,15,14,12,11,11,13,14,13,12,10,10,9,8,6,4],
[7,9,12,12,13,13,11,11,11,8,9,10,8,8,8,14,14,17,17,18,19,21,19,14,13,13,11,8,5,3],
[5,10,11,12,12,13,12,11,12,11,11,11,11,10,10,13,13,15,15,15,17,14,13,12,12,13,12,10,8,5],
[8,9,12,13,11,10,9,10,9,9,9,9,9,11,12,14,13,13,11,10,9,8,7,8,7,7,7,6,4,3],
[5,9,10,12,13,11,11,10,11,10,9,8,7,8,8,13,14,13,13,12,12,12,11,9,9,9,9,9,7,4],
[8,9,13,14,13,11,11,9,8,9,9,9,9,9,14,15,13,12,11,11,11,9,9,7,8,8,8,7,5,4],
[6,7,10,13,14,15,13,12,11,11,11,10,9,10,11,15,13,12,11,12,13,12,12,11,8,7,6,5,5,4],
[5,9,10,11,12,13,11,9,9,10,9,9,9,9,9,9,10,13,15,14,17,18,19,19,16,15,10,7,6,3],
[8,8,10,12,13,12,11,11,10,10,9,8,9,11,12,13,12,10,10,11,9,10,10,11,11,11,8,6,5,4],
[7,9,10,11,13,14,14,12,11,11,11,12,13,16,16,16,15,15,14,14,15,14,13,11,10,9,7,6,4,3],
[3,7,10,12,13,13,13,11,11,10,10,10,11,11,13,15,16,16,15,15,14,17,16,17,18,16,14,12,8,5],
[6,8,9,11,12,12,10,11,11,10,9,9,9,8,9,11,16,15,15,13,13,13,13,12,13,12,11,8,6,4],
[5,7,9,12,12,13,12,10,10,9,11,11,9,10,9,11,12,13,14,13,13,13,17,18,21,19,15,10,8,5],
[4,7,8,10,11,12,12,14,11,11,11,11,11,12,10,11,15,16,19,26,28,28,28,28,27,22,20,16,4,4],
[4,7,10,11,12,12,13,11,11,10,10,10,10,11,11,10,10,12,16,17,20,21,22,18,16,13,9,7,5,2],
[1,5,11,12,14,15,15,12,11,9,9,9,8,10,10,11,14,15,17,16,17,18,19,16,15,13,10,9,5,4] ])
### 4-fold cross-validation ###
kf = KFold(n_splits=4)
i4 = 0
i7 = 0
i9 = 0
c4_error = 0
c7_error = 0
c9_error = 0
final_accuracy = 0
print("\n")
print("...::: Class 4 cross-validation :::...")
for train_index, test_index in kf.split(c4):
    i4+=1
    print("\n")
    print("The iteration #", i4)
    X_train4, X_test4 = c4[train_index], c4[test_index]
    X_train7, X_test7 = c7[train_index], c7[test_index]
    X_train9, X_test9 = c9[train_index], c9[test_index]
    X_train4_mean = np.mean(X_train4, axis=0)
    X_train7_mean = np.mean(X_train7, axis=0)
    X_train9_mean = np.mean(X_train9, axis=0)
    v4 = mvn(X_train4_mean, cov = np.cov(X_train4.T) + np.eye(30))
    v7 = mvn(X_train7_mean, cov = np.cov(X_train7.T) + np.eye(30))
    v9 = mvn(X_train9_mean, cov = np.cov(X_train9.T) + np.eye(30))
    res4 = v4.pdf(X_test4)
    res7 = v7.pdf(X_test4)
    res9 = v9.pdf(X_test4)
    print("\n")
    print("In 4 - res4: ", res4)
    print("In 4 - res7: ", res7)
    print("In 4 - res9: ", res9)
    for x in range(0, 10):
        if res4[x] > res7[x]:
            if res4[x] > res9[x]:
                print("Sample",x, "belongs to class 4")
        elif res7[x] > res4[x]:
            if res7[x] > res9[x]:
                c4_error +=1
                print("Sample",x, "belongs to class 7, Error")
        elif res9[x] > res4[x]:
            if res9[x] > res7[x]:
                c4_error += 1
                print("Sample",x, "belongs to class 9, Error")
    print("\n")
if c4_error == 0:
    print("Average accuracy C4: 100% ")
else:
    print("Average accuracy C4: %", (c4_error / 10) * 100)
print("\n")
### The End ###
DRY
There is a lot of repetition in this code, especially in the complete version that handles all 3 classes. Imagine having to add a fourth class...
The way to tackle this is to split the program into logical chunks, each performing one part, and then glue those together.
Part of this is returning something useful instead of printing intermediate results. If you do need to see the intermediate results, I suggest looking at the built-in logging module.
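A minimal sketch of that idea, using a hypothetical helper count_errors: the function returns the number of mismatches, and the per-sample diagnostics go through logger.debug, which stays silent unless the logging level is lowered.

```python
import logging

logger = logging.getLogger(__name__)


def count_errors(predicted, expected):
    """Return the number of mismatches instead of printing each one."""
    errors = 0
    for index, (p, e) in enumerate(zip(predicted, expected)):
        if p != e:
            errors += 1
            # Only emitted when the logger is enabled for DEBUG messages.
            logger.debug("Sample %d classified as %s, expected %s", index, p, e)
    return errors


# Lower the level to DEBUG to see the per-sample messages; WARNING hides them.
logging.basicConfig(level=logging.DEBUG)
print(count_errors(["4", "7", "4"], ["4", "4", "4"]))  # prints 1
```

The caller decides what to do with the returned count, and the diagnostic output can be switched on or off in one place.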
Code
Let's work step by step.
sample definition
Instead of hardcoding a separate variable for each digit, I would use a dict:
samples = {
'4': c4,
'7': c7,
'9': c9,
}
split the samples
... along the train and test indices:
def split_sample(samples, train_index, test_index):
    for sample_name, sample in samples.items():
        yield sample_name, (sample[train_index], sample[test_index])
... which splits each sample into training and test data.
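To see what the generator yields, here is a small sketch with toy arrays; the sample data and the index lists are made up for illustration and stand in for c4, c7 and the KFold indices.

```python
import numpy as np


def split_sample(samples, train_index, test_index):
    # Yield (name, (train_rows, test_rows)) for every class in the dict.
    for sample_name, sample in samples.items():
        yield sample_name, (sample[train_index], sample[test_index])


# Toy data: 4 rows of 3 features per class.
samples = {
    '4': np.arange(12).reshape(4, 3),
    '7': np.arange(12, 24).reshape(4, 3),
}
splits = dict(split_sample(samples, train_index=[0, 1, 2], test_index=[3]))
x_train, x_test = splits['4']
print(x_train.shape, x_test.shape)  # (3, 3) (1, 3)
```

Because every class shares the same index arrays, all classes are split identically, which is what the original code did with the three separate assignments.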
make the distribution:
This uses the training data to build the multivariate normal distribution (mvn):
def make_distribution(x_train):
    x_train_mean = np.mean(x_train, axis=0)
    covariance = np.cov(x_train.T) + np.eye(30)
    return mvn(x_train_mean, cov=covariance)
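Why the + np.eye(30) term matters: with 40 rows per class and 4 folds, each class has 30 training rows and 30 features, so the sample covariance (which loses one degree of freedom to the mean) has rank at most 29 and is singular. By default scipy's multivariate_normal rejects a singular covariance unless allow_singular=True is passed. Adding the identity shifts every eigenvalue up by 1. A quick check, with random integers standing in for the real digit data:

```python
import numpy as np

rng = np.random.default_rng(0)
# 30 training rows x 30 features: the same shape as one fold's training set.
x_train = rng.integers(0, 30, size=(30, 30)).astype(float)

cov = np.cov(x_train.T)          # 30 x 30 sample covariance
regularised = cov + np.eye(30)   # the "+ np.eye(30)" term from make_distribution

print(np.linalg.matrix_rank(cov))             # at most 29, so singular
print(np.linalg.eigvalsh(regularised).min())  # about 1 or more, so positive definite
```

The 1 on the diagonal is a fixed amount of regularisation; a smaller value would bias the covariance less, but that is a tuning choice rather than something the original code tests.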
make the pdfs
def make_pdf(samples, sample_name):
    kf = KFold(n_splits=4)
    split = kf.split(samples[sample_name])
    for iteration, (train_index, test_index) in enumerate(split, 1):
        samples_splits = dict(split_sample(samples, train_index, test_index))
        distributions = (
            (name, make_distribution(x_train))
            for name, (x_train, _) in samples_splits.items()
        )
        x_test = samples_splits[sample_name][1]
        yield iteration, {
            name: distribution.pdf(x_test)
            for name, distribution in distributions
        }
find which class
Here your code contains a magic number, 10. I assume it is the length of the test data. Try to avoid magic numbers like that as much as possible; here you can do that by using zip.
Instead of for x in range(0, 10):, it is better to use enumerate:
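def find_class(pdfs):
    fits = zip(*([(p, name) for p in pdf] for name, pdf in pdfs.items()))
    for x, fit in enumerate(fits):
        order = [name for _, name in sorted(fit, reverse=True)]
        yield x, order
This sorts the pdf values of the different classes at each test point from high to low. On an exact tie the class names are compared as well, so because of reverse=True the tied classes come out in reverse alphabetical order. This is easier than using the nested ifs, which also silently skip some orderings: when res9 > res4 > res7, the first branch is taken, its inner test fails, and the sample is neither reported nor counted as an error.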
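A small example with made-up density values (plain lists work as well as arrays here) shows the ordering, including the tie at the last test point, where '7' wins over '4':

```python
def find_class(pdfs):
    # Pair every density value with its class name, then regroup per test point.
    fits = zip(*([(p, name) for p in pdf] for name, pdf in pdfs.items()))
    for x, fit in enumerate(fits):
        # Highest density first; equal densities are ordered by name, descending.
        order = [name for _, name in sorted(fit, reverse=True)]
        yield x, order


pdfs = {'4': [0.5, 0.1, 0.2], '7': [0.3, 0.6, 0.2], '9': [0.2, 0.3, 0.1]}
print(dict(find_class(pdfs)))
# {0: ['4', '7', '9'], 1: ['7', '9', '4'], 2: ['7', '4', '9']}
```

If only the winning class is needed, order[0] is all the later code uses.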
gluing it together:
To calculate the accuracy, instead of dividing it manually by 10, use sum and len:
for sample_name in samples.keys():
    for iteration, pdfs in make_pdf(samples, sample_name):
        classes = dict(find_class(pdfs))
        # print(classes)
        matches = [order[0] == sample_name for order in classes.values()]
        # print(matches)
        accuracy = sum(matches) / len(classes)
        print(f'accuracy for `{sample_name}` in iteration {iteration}: {accuracy}')
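Because True counts as 1 and False as 0, sum(matches) / len(matches) is the fraction of correct predictions. A sketch with hypothetical per-fold results, also showing how the per-fold figures could be combined into one overall accuracy, assuming equally sized folds (as with 40 rows and 4 splits):

```python
# Hypothetical per-fold match lists: True where the top-ranked class was correct.
fold_matches = [
    [True, True, False, True],
    [True, True, True, True],
]
fold_accuracies = [sum(m) / len(m) for m in fold_matches]
print(fold_accuracies)  # [0.75, 1.0]

# With equally sized folds, the mean of the fold accuracies is the overall accuracy.
overall = sum(fold_accuracies) / len(fold_accuracies)
print(overall)  # 0.875
```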
Testing
By doing it like this, you can test each part of the code individually.
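For example, find_class can be checked with hand-made inputs, without any real data. A sketch using the standard unittest module; find_class is repeated so the snippet runs on its own, and the test names are illustrative:

```python
import unittest


def find_class(pdfs):
    fits = zip(*([(p, name) for p in pdf] for name, pdf in pdfs.items()))
    for x, fit in enumerate(fits):
        order = [name for _, name in sorted(fit, reverse=True)]
        yield x, order


class FindClassTest(unittest.TestCase):
    def test_highest_density_first(self):
        pdfs = {'4': [0.9, 0.1], '7': [0.05, 0.8], '9': [0.05, 0.1]}
        result = dict(find_class(pdfs))
        self.assertEqual(result[0][0], '4')
        self.assertEqual(result[1][0], '7')

    def test_one_entry_per_test_point(self):
        pdfs = {'4': [0.1, 0.2, 0.3], '7': [0.3, 0.2, 0.1]}
        self.assertEqual(len(dict(find_class(pdfs))), 3)
```

Saved as, say, test_find_class.py, it runs with python -m unittest test_find_class.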
main guard
It is best to put all of the script-specific code (not the functions, but for example the definition of samples) behind an if __name__ == '__main__': guard, so you can later import the functions from another script or program.
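A minimal layout sketch; accuracy and main are hypothetical names, and the real samples dict and evaluation loop would go inside main:

```python
def accuracy(matches):
    """Reusable helper: importing the module defines it without side effects."""
    return sum(matches) / len(matches)


def main():
    # Script-specific code (data definitions, the loop over samples, printing)
    # lives here, so importing this module does not execute it.
    matches = [True, False, True, True]  # placeholder for the real results
    print(f'accuracy: {accuracy(matches)}')


if __name__ == '__main__':
    main()
```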
VI suppose. Where did the values in your numpy arrays (you call them a class, they're not a class) come from? What's your current average accuracy on the data set you used? Is the data set available somewhere?