(See the next iteration.)
I have this funky class method for computing the standard deviation of an array of Number objects using the Stream API:
StandardDeviation.java:
package net.coderodde.util;

import java.util.Arrays;

public class StandardDeviation {

    public static double computeStandardDeviation(Number... collection) {
        if (collection.length == 0) {
            return Double.NaN;
        }

        final double average =
                Arrays.stream(collection)
                      .mapToDouble((x) -> x.doubleValue())
                      .summaryStatistics()
                      .getAverage();

        final double rawSum =
                Arrays.stream(collection)
                      .mapToDouble((x) -> Math.pow(x.doubleValue() - average,
                                                   2.0))
                      .sum();

        return Math.sqrt(rawSum / (collection.length - 1));
    }

    public static void main(String[] args) {
        // Mix 'em all!
        double sd = computeStandardDeviation((byte) 1,
                                             (short) 2,
                                             3,
                                             4L,
                                             5.0f,
                                             6.0);
        System.out.println(sd);
    }
}
Please, tell me anything that comes to mind.
That's a nice and clean, short piece of code. :) It also appears to calculate the result correctly. My observations will follow, but really, only the first one is actually an issue:
- The application will crash with a NullPointerException if you try to compute the standard deviation for null. In addition to checking collection.length == 0, you should first check for collection == null.
- You should be consistent with your use of the final modifier. Declare Number... collection, String[] args and double sd final too.
- While you're at it, you could make the entire class final as well and add an empty private constructor: private StandardDeviation() { /* prevent instantiation */ }. It's a utility class with only static methods, so as it currently stands nobody needs to call new StandardDeviation(). Some static code analyzers would actually flag this as a minor issue.
- Personally, I find method references nice and clear, so you could replace .mapToDouble((x) -> x.doubleValue()) with .mapToDouble(Number::doubleValue).
- I find that .mapToDouble((x) -> Math.pow(x.doubleValue() - average, 2.0)) does too much at once. You could do it like this instead: .mapToDouble(Number::doubleValue).map(x -> x - average).map(StandardDeviation::square), and then add a new method like this to your class: private static double square(final double value) { return Math.pow(value, 2.0); }. At least the Math.pow(...) part is much easier to read as a method reference.
- As an external user of your utility class, I'd like the option to call the standard deviation method with a Collection instead of varargs parameters. You'd then have to make one method call the other, first converting either the array into a collection or the other way around. Personally, I prefer working with Collections.
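Putting the suggestions above together, the class might look roughly like this. It is only a sketch, not the asker's actual follow-up code. It assumes that a null input should be treated like an empty one and return Double.NaN (throwing an IllegalArgumentException would be an equally defensible choice), and the Collection overload simply converts to an array and delegates. The class is left package-private so the snippet compiles in a file of any name.

```java
import java.util.Arrays;
import java.util.Collection;

// Sketch combining the review suggestions (not the asker's actual code).
// Assumption: a null input is treated like an empty one and yields NaN.
final class StandardDeviation {

    private StandardDeviation() { /* prevent instantiation */ }

    public static double computeStandardDeviation(final Number... collection) {
        if (collection == null || collection.length == 0) {
            return Double.NaN;
        }

        final double average = Arrays.stream(collection)
                                     .mapToDouble(Number::doubleValue)
                                     .summaryStatistics()
                                     .getAverage();

        final double rawSum = Arrays.stream(collection)
                                    .mapToDouble(Number::doubleValue)
                                    .map(x -> x - average)
                                    .map(StandardDeviation::square)
                                    .sum();

        // Divides by n - 1 (sample standard deviation), like the original.
        return Math.sqrt(rawSum / (collection.length - 1));
    }

    // Collection overload: converts to an array and delegates to the varargs version.
    public static double computeStandardDeviation(
            final Collection<? extends Number> collection) {
        if (collection == null) {
            return Double.NaN;
        }
        return computeStandardDeviation(collection.toArray(new Number[0]));
    }

    private static double square(final double value) {
        return Math.pow(value, 2.0);
    }

    public static void main(final String[] args) {
        final double sd = computeStandardDeviation((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0);
        System.out.println(sd);
        System.out.println(computeStandardDeviation(Arrays.asList(1, 2, 3, 4, 5, 6)));
    }
}
```

Because overload resolution tries non-varargs candidates first, passing a List picks the Collection overload, while a bare list of numbers still goes to the varargs one.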
- You are traversing the collection twice to determine the standard deviation, when you could do it in a single pass.
- Also, rounding errors can quickly accumulate when summing the Math.pow(x.doubleValue() - average, 2.0) terms. It would be best to use the Kahan summation algorithm (which the Stream API uses for DoubleStream#sum()).
- In the lambda expression (x) -> x.doubleValue(), you don't need the parentheses around x. You can just write x -> x.doubleValue(). You could also use a method reference, which avoids the lambda altogether, and write Number::doubleValue.
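To illustrate what compensated (Kahan) summation buys you, here is a minimal stand-alone sketch; the class and method names are hypothetical. Adding 10,000 copies of 1e-16 to 1.0 loses every small term in a naive loop, because each one is below half an ulp of 1.0, while the compensated loop carries the lost low-order bits forward and ends up close to 1.0 + 1e-12.

```java
import java.util.Arrays;

// Minimal sketch of Kahan (compensated) summation; names are hypothetical.
final class KahanDemo {

    private KahanDemo() { }

    static double naiveSum(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum;
    }

    static double kahanSum(double[] values) {
        double sum = 0.0;
        double compensation = 0.0; // negated low-order bits lost so far
        for (double v : values) {
            double y = v - compensation;  // re-inject the bits lost last time
            double t = sum + y;           // low-order bits of y may be rounded away here
            compensation = (t - sum) - y; // (what was actually added) - (what we wanted to add)
            sum = t;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] values = new double[10_001];
        values[0] = 1.0;
        Arrays.fill(values, 1, values.length, 1e-16);
        System.out.println(naiveSum(values));
        System.out.println(kahanSum(values));
    }
}
```

The compensation variable is the same trick DoubleStatistics uses for its sum of squares in sumOfSquareWithCompensation.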
On Stack Overflow, I wrote an answer which calculates the standard deviation in a single pass with compensation. It is parallel-friendly:
import java.util.DoubleSummaryStatistics;
import java.util.stream.Collector;

static class DoubleStatistics extends DoubleSummaryStatistics {

    private double sumOfSquare = 0.0d;
    private double sumOfSquareCompensation; // Low order bits of sum
    private double simpleSumOfSquare; // Used to compute right sum for
                                      // non-finite inputs

    @Override
    public void accept(double value) {
        super.accept(value);
        double squareValue = value * value;
        simpleSumOfSquare += squareValue;
        sumOfSquareWithCompensation(squareValue);
    }

    public DoubleStatistics combine(DoubleStatistics other) {
        super.combine(other);
        simpleSumOfSquare += other.simpleSumOfSquare;
        sumOfSquareWithCompensation(other.sumOfSquare);
        // The other side's true partial sum is sumOfSquare - compensation,
        // so its compensation must be subtracted, not added
        sumOfSquareWithCompensation(-other.sumOfSquareCompensation);
        return this;
    }

    private void sumOfSquareWithCompensation(double value) {
        double tmp = value - sumOfSquareCompensation;
        double velvel = sumOfSquare + tmp; // Little wolf of rounding error
        sumOfSquareCompensation = (velvel - sumOfSquare) - tmp;
        sumOfSquare = velvel;
    }

    public double getSumOfSquare() {
        // The compensation holds the rounding error with the opposite sign
        double tmp = sumOfSquare - sumOfSquareCompensation;
        if (Double.isNaN(tmp) && Double.isInfinite(simpleSumOfSquare)) {
            return simpleSumOfSquare;
        }
        return tmp;
    }

    public final double getStandardDeviation() {
        long count = getCount();
        double sumOfSquare = getSumOfSquare();
        double average = getAverage();
        return count > 0
                ? Math.sqrt((sumOfSquare - count * Math.pow(average, 2)) / (count - 1))
                : 0.0d;
    }

    public static Collector<Double, ?, DoubleStatistics> collector() {
        return Collector.of(DoubleStatistics::new,
                            DoubleStatistics::accept,
                            DoubleStatistics::combine);
    }
}
It has the same logic as DoubleSummaryStatistics, but extended to calculate the sum of squares.
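Keeping only the count, the average and the sum of squares is enough because of the following identity, which getStandardDeviation() relies on (n is getCount(), \bar{x} is getAverage(), and \sum x_i^2 is getSumOfSquare()):

$$
\sum_{i=1}^{n} (x_i - \bar{x})^2
  = \sum_{i=1}^{n} x_i^2 - 2\bar{x}\sum_{i=1}^{n} x_i + n\bar{x}^2
  = \sum_{i=1}^{n} x_i^2 - n\bar{x}^2,
\qquad \text{using } \sum_{i=1}^{n} x_i = n\bar{x}.
$$

Dividing by n - 1 and taking the square root gives the sample standard deviation, the same quantity the two-pass code in the question computes:

$$
s = \sqrt{\frac{\sum_{i=1}^{n} x_i^2 - n\bar{x}^2}{n - 1}}
$$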
With such a class, you can then have:
public static double computeStandardDeviation(Number... collection) {
    return Arrays.stream(collection)
                 .map(Number::doubleValue)
                 .collect(DoubleStatistics.collector())
                 .getStandardDeviation();
}
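One caveat with the sum-of-squares formula: when the mean is large compared with the spread of the data, sumOfSquare and count * average² are large and nearly equal, so their difference can lose most of its significant digits. Compensated summation reduces the error in each sum, but not this cancellation. A commonly used alternative single-pass technique is Welford's online algorithm, which tracks the running mean and the sum of squared deviations directly; partial results from a parallel stream can be merged with the pairwise formula of Chan et al. The WelfordStatistics class below is a hypothetical sketch, not part of the original answer or of the JDK, and returning Double.NaN for fewer than two values is an assumption.

```java
import java.util.Arrays;
import java.util.stream.Collector;

// Hypothetical alternative to DoubleStatistics (not from the original answer):
// Welford's online algorithm for a single-pass variance, plus Chan et al.'s
// pairwise formula to merge partial results.
final class WelfordStatistics {

    private long count;
    private double mean;
    private double m2; // sum of squared deviations from the running mean

    void accept(double x) {
        count++;
        double delta = x - mean;   // deviation from the old mean
        mean += delta / count;
        m2 += delta * (x - mean);  // old-mean deviation times new-mean deviation
    }

    WelfordStatistics combine(WelfordStatistics other) {
        if (other.count == 0) {
            return this;
        }
        if (count == 0) {
            count = other.count;
            mean = other.mean;
            m2 = other.m2;
            return this;
        }
        long n = count + other.count;
        double delta = other.mean - mean;
        mean += delta * other.count / n;
        m2 += other.m2 + delta * delta * ((double) count * other.count / n);
        count = n;
        return this;
    }

    // Sample (n - 1) standard deviation; NaN for fewer than two values (assumption).
    double getSampleStandardDeviation() {
        return count > 1 ? Math.sqrt(m2 / (count - 1)) : Double.NaN;
    }

    static Collector<Number, ?, WelfordStatistics> collector() {
        return Collector.of(WelfordStatistics::new,
                            (stats, x) -> stats.accept(x.doubleValue()),
                            WelfordStatistics::combine);
    }

    static double computeStandardDeviation(Number... values) {
        return Arrays.stream(values)
                     .collect(collector())
                     .getSampleStandardDeviation();
    }
}
```

For example, computeStandardDeviation(1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16) yields sqrt(30) here, the same as for 4, 7, 13, 16, because the algorithm never forms the large squares whose difference would cancel.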
- You are very knowledgeable about Java's APIs. – coderodde, Apr 29, 2018 at 8:41
- This version calculates the standard deviation as count > 0 ? Math.sqrt((sumOfSquare - count * Math.pow(average, 2)) / (count - 1)) : 0.0d; but the version you link to on Stack Overflow has getCount() > 0 ? Math.sqrt((getSumOfSquare() / getCount()) - Math.pow(getAverage(), 2)) : 0.0d; Are they equivalent? – Damian Helme, Jan 3, 2023 at 12:15
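The two expressions in the comment above are not equivalent. The one in this answer divides the sum of squared deviations by count - 1 (the sample standard deviation, matching the question's code), while the quoted one divides the same sum by count (the population standard deviation). The hypothetical FormulaComparison class below evaluates both formulas from the same running totals for the question's data 1 to 6.

```java
// Hypothetical check of the two formulas quoted in the comment,
// evaluated from the same running totals (count, sum of squares, average).
final class FormulaComparison {

    private FormulaComparison() { }

    // Formula from this answer: sample standard deviation (divides by n - 1).
    static double sample(long count, double sumOfSquare, double average) {
        return Math.sqrt((sumOfSquare - count * Math.pow(average, 2)) / (count - 1));
    }

    // Formula quoted from the Stack Overflow version: population standard deviation (divides by n).
    static double population(long count, double sumOfSquare, double average) {
        return Math.sqrt(sumOfSquare / count - Math.pow(average, 2));
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4, 5, 6};
        long count = xs.length;
        double sum = 0.0;
        double sumOfSquare = 0.0;
        for (double x : xs) {
            sum += x;
            sumOfSquare += x * x;
        }
        double average = sum / count;
        double s = sample(count, sumOfSquare, average);
        double sigma = population(count, sumOfSquare, average);
        System.out.println(s);
        System.out.println(sigma);
        System.out.println(s / sigma);                               // ratio of the two results
        System.out.println(Math.sqrt((double) count / (count - 1))); // sqrt(n / (n - 1))
    }
}
```

The ratio of the two results is sqrt(n / (n - 1)), so they converge for large n but differ noticeably for small samples.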