I have implemented the kNN algorithm to classify handwritten digits. Each digit starts out as an 8×8 image, which is flattened into a 1×64 vector.
As it stands, my code runs kNN with a value of k chosen by the user. The training dataset can be found here and the validation set here.
ImageMatrix.java
import java.util.*;

public class ImageMatrix {
    private int[] data;
    private int classCode;
    private int curData;

    public ImageMatrix(int[] data, int classCode) {
        assert data.length == 64; // expects exactly 64 values (8x8 image)
        this.data = data;
        this.classCode = classCode;
    }

    public String toString() {
        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; // readable output
    }

    public int[] getData() {
        return data;
    }

    public int getClassCode() {
        return classCode;
    }

    public int getCurData() {
        return curData;
    }
}
ImageMatrixDB.java
import java.util.Map.Entry;
import java.util.*;
import java.io.*;

public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line = null;
            while ((line = br.readLine()) != null) {
                // The last column is the class code; everything before it is pixel data
                int lastComma = line.lastIndexOf(',');
                int classCode = Integer.parseInt(line.substring(1 + lastComma));
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..
                list.add(matrix);
            }
        }
        return this;
    }

    public void printResults() { // output results
        for (ImageMatrix matrix : list) {
            System.out.println(matrix);
        }
    }

    public Iterator<ImageMatrix> iterator() {
        return this.list.iterator();
    }

    /// kNN implementation ///
    public static int distance(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int) Math.sqrt(sum);
    }

    public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) { // Classifier with changeable value for k.
        int label = 0, bestDistance = Integer.MAX_VALUE;
        int[][] distances = new int[trainingSet.size()][2];
        int i = 0;
        // Place distances in an array to be sorted
        for (ImageMatrix matrix : trainingSet) {
            distances[i][0] = distance(matrix.getData(), curData);
            distances[i][1] = matrix.getClassCode();
            i++;
        }
        Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0] - rhs[0]);
        // Find frequencies of each class code
        i = 0;
        Map<Integer, Integer> majorityMap;
        majorityMap = new HashMap<Integer, Integer>();
        while (i < k) {
            if (majorityMap.containsKey(distances[i][1])) {
                int currentValue = majorityMap.get(distances[i][1]);
                majorityMap.put(distances[i][1], currentValue + 1);
            }
            else {
                majorityMap.put(distances[i][1], 1);
            }
            ++i;
        }
        // Find the class code with the highest frequency
        int maxVal = -1;
        for (Entry<Integer, Integer> entry : majorityMap.entrySet()) {
            int entryVal = entry.getValue();
            if (entryVal > maxVal) {
                maxVal = entryVal;
                label = entry.getKey();
            }
        }
        return label;
    }

    public int size() {
        return list.size(); // returns size of the list
    }

    public static void main(String[] argv) throws IOException {
        ImageMatrixDB trainingSet = new ImageMatrixDB();
        ImageMatrixDB validationSet = new ImageMatrixDB();
        Scanner scanner = new Scanner(System.in);
        trainingSet.load("cw2DataSet1.csv");
        validationSet.load("cw2DataSet2.csv");
        int numCorrect = 0;
        System.out.println("Enter the value of k:");
        int k;
        k = Integer.parseInt(scanner.nextLine());
        for (ImageMatrix matrix : validationSet) {
            if (classify(trainingSet, matrix.getData(), k) == matrix.getClassCode()) numCorrect++;
        }
        System.out.println("kNN Accuracy: " + (double) numCorrect / validationSet.size() * 100 + "%"); // Accuracy as a readable percentage
        System.out.println();
    }

    //////////////////////////////////////////
    // Previous working dataset Load //
    /* public static void main(String[] args){
        ImageMatrixDB i = new ImageMatrixDB();
        try{
            i.load("cw2DataSet1.csv");
            i.printResults();
        }
        catch(Exception ex){
            ex.printStackTrace();
        }
    } */
}
1 Answer
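Don't use assert to check input data
For this purpose it is better to throw an exception. It is good practice to use exceptions when invalid data comes from outside, and assertions when invalid data could only be produced inside the class or function itself. In other words, use assertions only for conditions that must always be true. Note also that Java assertions are disabled by default (they only run when the JVM is started with -ea), so the length check in ImageMatrix may never execute.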
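As a minimal sketch of that advice, the constructor check could look roughly like this. The constant name PIXEL_COUNT, the defensive copy, and the exact error message are my own assumptions, not something taken from the question.

```java
import java.util.Arrays;

// Sketch only: the question's ImageMatrix with the assert replaced by an exception.
class ImageMatrix {
    // Assumed constant name; 64 = 8 x 8 pixels flattened into one row.
    private static final int PIXEL_COUNT = 64;

    private final int[] data;
    private final int classCode;

    ImageMatrix(int[] data, int classCode) {
        // The values come from a CSV file, so a malformed row is a realistic input error.
        if (data == null || data.length != PIXEL_COUNT) {
            throw new IllegalArgumentException(
                    "Expected " + PIXEL_COUNT + " values but got "
                            + (data == null ? "null" : String.valueOf(data.length)));
        }
        this.data = data.clone(); // defensive copy (an extra, optional choice)
        this.classCode = classCode;
    }

    int[] getData() {
        return data;
    }

    int getClassCode() {
        return classCode;
    }

    @Override
    public String toString() {
        return "Class Code: " + classCode + " Data: " + Arrays.toString(data);
    }
}
```

Unlike the assert, this check runs regardless of JVM flags, and the caller that loads the file can decide whether to skip the bad row or abort.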
Think about good variable names
For example, k is a bad name. Try to describe in the name what the variable stores.
I also don't like the name list. A list of what? I assume a list of matrices, so why not call it matrices or images? That kind of renaming has another advantage: the name stays valid after you replace the list with another data structure, such as a queue or a stack.
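To illustrate the naming point, here is a small sketch. The class name DigitImages is hypothetical, and each int[] simply stands in for one 64-value image so that the example is self-contained.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

// Hypothetical class name; each int[] stands in for one flattened digit image.
class DigitImages implements Iterable<int[]> {
    // Named after what it holds and typed as Collection, so swapping
    // ArrayList for another container does not force a rename.
    private final Collection<int[]> images = new ArrayList<>();

    void add(int[] image) {
        images.add(image);
    }

    int size() {
        return images.size();
    }

    @Override
    public Iterator<int[]> iterator() {
        return images.iterator();
    }
}
```

The same idea applies to the classifier parameter: something like neighbourCount (again, just one possible name) says what the number means without the reader having to know the algorithm's notation.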
Create smaller classes
Don't mix classification and CSV reading in one class. Moreover, put the main function in a separate class.
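One possible way to split the responsibilities is sketched below. All class and method names (LabelledImage, CsvImageReader, KnnClassifier, KnnDemo) are assumptions of mine, and the tiny two-pixel data set exists only to keep the demo short. Two deliberate differences from the question's code: the sketch compares squared distances instead of truncating the square root to an int, and on a tied vote it keeps the label that reached the winning count first.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** One labelled image: pixel values plus the digit they represent. */
final class LabelledImage {
    final int[] pixels;
    final int label;

    LabelledImage(int[] pixels, int label) {
        this.pixels = pixels;
        this.label = label;
    }
}

/** Only knows how to turn CSV text into images. */
final class CsvImageReader {
    static List<LabelledImage> read(BufferedReader reader) throws IOException {
        List<LabelledImage> images = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            if (line.trim().isEmpty()) {
                continue; // skip blank lines
            }
            int lastComma = line.lastIndexOf(','); // last column is the label
            int label = Integer.parseInt(line.substring(lastComma + 1).trim());
            int[] pixels = Arrays.stream(line.substring(0, lastComma).split(","))
                    .map(String::trim)
                    .mapToInt(Integer::parseInt)
                    .toArray();
            images.add(new LabelledImage(pixels, label));
        }
        return images;
    }
}

/** Only knows how to classify; it never touches files. */
final class KnnClassifier {
    private final List<LabelledImage> trainingImages;
    private final int neighbourCount;

    KnnClassifier(List<LabelledImage> trainingImages, int neighbourCount) {
        if (neighbourCount < 1 || neighbourCount > trainingImages.size()) {
            throw new IllegalArgumentException("neighbourCount out of range: " + neighbourCount);
        }
        this.trainingImages = trainingImages;
        this.neighbourCount = neighbourCount;
    }

    int classify(int[] pixels) {
        // Sort a copy of the training images, nearest first.
        List<LabelledImage> byDistance = new ArrayList<>(trainingImages);
        byDistance.sort(Comparator.comparingLong(
                (LabelledImage image) -> squaredDistance(image.pixels, pixels)));
        // Count votes among the nearest neighbours and track the current leader.
        Map<Integer, Integer> votes = new HashMap<>();
        int bestLabel = -1;
        int bestVotes = 0;
        for (LabelledImage neighbour : byDistance.subList(0, neighbourCount)) {
            int count = votes.merge(neighbour.label, 1, Integer::sum);
            if (count > bestVotes) {
                bestVotes = count;
                bestLabel = neighbour.label;
            }
        }
        return bestLabel;
    }

    // Squared Euclidean distance gives the same ordering as the true distance.
    private static long squaredDistance(int[] a, int[] b) {
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }
}

/** Entry point kept apart from both reading and classifying. */
final class KnnDemo {
    public static void main(String[] args) throws IOException {
        // Toy data: two pixel columns and a label column (not the real 64-pixel format).
        String csv = "0,0,0\n0,1,0\n9,9,1\n9,8,1\n";
        List<LabelledImage> training =
                CsvImageReader.read(new BufferedReader(new StringReader(csv)));
        KnnClassifier classifier = new KnnClassifier(training, 3);
        System.out.println(classifier.classify(new int[] {8, 9})); // prints 1
    }
}
```

Because the reader takes a BufferedReader rather than a file name, the same code works with a FileReader for the real data sets and with a StringReader in tests.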