4
\$\begingroup\$
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.Scanner;
/**
 * Solution for Quora Classifier question from CodeSprint 2012 This class
 * implements a logistic regression classifier trained with Stochastic Gradient
 * Descent. No regularization is used as it wasn't necessary for this problem
 */
public class QuoraClassifier {
    /** When true, predictions are validated against expected labels and diagnostics printed. */
    static boolean debug;
    double[] theta; // logistic regression coefficients, one per feature
    double[] yTraining; // training targets, stored as 0/1
    double[][] xTraining; // raw (UNnormalized) training predictors
    double[] xMean; // per-feature mean over the training set
    double[] xVarSqrt; // per-feature sqrt(variance); 1.0 for constant features
    double alpha = 5; // SGD learning rate
    int numTrainingExamples; // number of training examples
    int numFeatures; // number of features per example

    /**
     * Loads the training block: a header line "numExamples numFeatures" followed by
     * one line per example of the form "id target f1:v1 f2:v2 ...".
     * Targets are converted from -1/+1 to 0/1 so they plug directly into the
     * logistic loss gradient.
     *
     * @param sc scanner positioned at the start of the training block
     */
    private void loadTrainingData(Scanner sc) {
        // "\\s+" (not "\\s") tolerates runs of whitespace between tokens.
        String[] header = sc.nextLine().split("\\s+");
        numTrainingExamples = Integer.parseInt(header[0]);
        numFeatures = Integer.parseInt(header[1]);
        yTraining = new double[numTrainingExamples];
        xTraining = new double[numTrainingExamples][numFeatures];
        for (int i = 0; i < numTrainingExamples; i++) {
            String[] fields = sc.nextLine().split("\\s+");
            // fields[0] is the example id; fields[1] is the -1/+1 target.
            yTraining[i] = (Double.parseDouble(fields[1]) + 1) / 2;
            for (int f = 0; f < numFeatures; f++) {
                String feature = fields[2 + f]; // "index:value"
                xTraining[i][f] = Double.parseDouble(feature.substring(feature.indexOf(':') + 1));
            }
        }
    }

    /**
     * Computes the per-feature mean and standard deviation of the training
     * predictors. The raw data is deliberately left untouched: normalization is
     * applied exactly once, on the fly, in {@link #classify(double[])}.
     * (The previous version normalized xTraining in place AND again in
     * classify(), so training examples were normalized twice while query
     * examples were normalized once — the learned coefficients and the
     * query-time features were on inconsistent scales.)
     */
    private void computeFeatureStatistics() {
        xMean = new double[numFeatures];
        xVarSqrt = new double[numFeatures];
        for (int f = 0; f < numFeatures; f++) {
            double sum = 0.0;
            for (int i = 0; i < numTrainingExamples; i++) {
                sum += xTraining[i][f];
            }
            xMean[f] = sum / numTrainingExamples;
            double sumSq = 0.0;
            for (int i = 0; i < numTrainingExamples; i++) {
                double d = xTraining[i][f] - xMean[f];
                sumSq += d * d;
            }
            // Guard against division by zero for constant features.
            xVarSqrt[f] = sumSq > 0.0 ? Math.sqrt(sumSq / numTrainingExamples) : 1.0;
        }
    }

    /**
     * Trains the regression coefficients with a single stochastic-gradient-descent
     * pass over the training set. No regularization is applied (not needed for
     * this problem). The gradient uses the same normalized features that
     * classify() uses, keeping training and prediction consistent.
     */
    private void trainLogistic() {
        theta = new double[numFeatures];
        for (int i = 0; i < numTrainingExamples; i++) {
            double error = classify(xTraining[i]) - yTraining[i];
            for (int f = 0; f < numFeatures; f++) {
                double xNorm = (xTraining[i][f] - xMean[f]) / xVarSqrt[f];
                theta[f] -= alpha / numTrainingExamples * error * xNorm;
            }
        }
    }

    /**
     * Classifies a raw (unnormalized) feature vector. Features are normalized
     * here by the training-set mean and standard deviation before the dot
     * product with theta.
     *
     * @param x raw feature values, length numFeatures
     * @return soft classification in (0, 1); values above 0.5 mean class +1
     */
    private double classify(double[] x) {
        double z = 0;
        for (int f = 0; f < numFeatures; f++) {
            double xNorm = (x[f] - xMean[f]) / xVarSqrt[f];
            z += theta[f] * xNorm;
        }
        return sigmoid(z);
    }

    /**
     * Classifies the query block ("numQueries" line, then "label f1:v1 ..."
     * lines) and writes "label +1" or "label -1" per query to System.out.
     * In debug mode, misclassifications are reported against scOut.
     *
     * @param scIn  input data, positioned at the query-count line
     * @param scOut expected labels for validation; may be null when debug is false
     */
    private void classifyQueries(Scanner scIn, Scanner scOut) {
        int numQueries = Integer.parseInt(scIn.nextLine());
        double[] xQuery = new double[numFeatures];
        double numMisclassifications = 0;
        for (int i = 0; i < numQueries; i++) {
            String[] queryString = scIn.nextLine().split("\\s+");
            String label = queryString[0];
            for (int f = 0; f < numFeatures; f++) {
                String feature = queryString[1 + f];
                xQuery[f] = Double.parseDouble(feature.substring(feature.indexOf(':') + 1));
            }
            double classification = classify(xQuery);
            if (classification > 0.5) {
                System.out.printf("%s +1\n", label);
            } else {
                System.out.printf("%s -1\n", label);
            }
            if (debug) {
                // Compare against the expected "label target" line.
                String[] outString = scOut.nextLine().split("\\s+");
                String correctClassification = outString[1];
                if (classification > 0.5) {
                    if (correctClassification.equals("-1")) {
                        numMisclassifications++;
                        System.out.printf("%s +1 %s\n", label, correctClassification);
                    }
                } else {
                    if (correctClassification.equals("+1")) {
                        numMisclassifications++;
                        System.out.printf("%s -1 %s\n", label, correctClassification);
                    }
                }
            }
        }
        if (debug) {
            System.out.printf("Misclassification rate: %.1f%%\n",
                    (100.0 * numMisclassifications) / numQueries);
        }
    }

    /**
     * Logistic (sigmoid) function.
     *
     * @param z any real value
     * @return 1 / (1 + exp(-z)), in (0, 1)
     */
    private double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Runs the full pipeline: load training data, compute normalization
     * statistics, train, then classify the queries.
     *
     * @param scIn  input data (training set followed by queries)
     * @param scOut expected query labels for validation; may be null when debug is false
     */
    public void run(Scanner scIn, Scanner scOut) {
        loadTrainingData(scIn);
        computeFeatureStatistics();
        trainLogistic();
        classifyQueries(scIn, scOut);
    }

    /**
     * Entry point. With no arguments, reads input from System.in and writes to
     * System.out (competition mode). With two arguments, reads input from the
     * file named by args[0], validates against the expected output in args[1],
     * and profiles execution time. With exactly one argument, prints a usage
     * message instead of crashing (previously this threw
     * ArrayIndexOutOfBoundsException).
     */
    public static void main(String[] args) throws FileNotFoundException {
        Scanner scIn; // input data
        Scanner scOut; // expected output, debug mode only
        if (args.length >= 2) { // file mode (for testing)
            scIn = new Scanner(new BufferedReader(new FileReader(new File(args[0]))));
            scOut = new Scanner(new BufferedReader(new FileReader(new File(args[1]))));
            debug = true;
        } else if (args.length == 1) {
            System.err.println("Usage: QuoraClassifier [inputFile expectedOutputFile]");
            return;
        } else { // stdin mode (used in competition)
            scIn = new Scanner(System.in);
            scOut = null;
            debug = false;
        }
        long startTime = debug ? System.nanoTime() : 0;
        QuoraClassifier classifier = new QuoraClassifier();
        classifier.run(scIn, scOut);
        if (debug) {
            long endTime = System.nanoTime();
            System.out.printf("Execution time: %f\n", (endTime - startTime) / 1e9);
        }
        scIn.close();
        if (scOut != null) {
            scOut.close();
        }
    }
}
asked Jan 14, 2012 at 20:35
\$\endgroup\$

2 Answers 2

3
\$\begingroup\$

Just off the top of my head:

  • In `main` — what should happen if `args.length == 1`? Currently, you'll get an `ArrayIndexOutOfBoundsException` when `args[1]` is read.
  • Your field names are a bit odd - xTraining and yTraining in particular. Instead of having these names, and comments about what the variables actually mean, why not just call your variables regressionCoefficients, trainingTargets, trainingPredictors and so on?
  • Some of the methods are a bit long. The fact that your methods mostly have internal comments kind of indicates this. In my opinion, each method should do just one thing; and that one thing should be explained in a javadoc comment at the top of the method. No comments at all inside methods.
  • Comments should be used to explain things, not just to repeat what's in the code. For example, your "return" comment at the top of sigmoid is pointless. I suggest removing it.
  • Do you really want to use System.out.printf for logging? There are more versatile ways of logging (which you could google), and more readable means of formatting.
answered Jan 16, 2012 at 4:16
\$\endgroup\$
2
\$\begingroup\$

Same answer as your other question.

You should have some class to load (and contain) the original data; some Normalizer interface (you might have many different classes that implement different normalizations); some Classifier interface; you are training something, so that something should be a class with a train(?,...) method. And then again, there are probably many ways to train it, so there should be some interface for that too.

answered Jan 15, 2012 at 15:59
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.