As an exercise to help me learn Haskell, I decided to write a small "program" to find the Pearson Correlation Coefficient for 2 lists of numbers. I'm pretty unhappy with how it turned out, because I feel like the typing is a mess, and it is difficult to read.
I was hoping that someone with experience would help me to improve my code, and tell me how they would have approached the problem.
import Data.List

summation :: (Integral a, Fractional b) => a -> a -> (a -> b) -> b
summation i n e = if (i < n)
  then (e i + summation (i+1) n e)
  else (e i)

mean :: (Real a, Fractional b) => [a] -> b
mean x = (1 / (genericLength x)) *
  (summation 1 (length x) (\i -> realToFrac (x !! (i-1))))

covariance :: (Real a, Fractional b) => [a] -> [a] -> b
covariance x y = (1 / (genericLength x)) *
  (summation 1 (length x) (\i -> ((realToFrac (x !! (i-1)) - mean x) * (realToFrac (y !! (i-1)) - mean x))))

stddev :: (Real a, Floating b) => [a] -> b
stddev x = ((1 / (genericLength x)) *
  (summation 1 (length x) (\i -> (realToFrac (x !! (i-1)) - mean x) ** 2))) ** (1/2)

pearson :: (Real a, Floating b) => [a] -> [a] -> b
pearson x y = covariance x y / (stddev x * stddev y)
Each function on its own
Let us first have a look at each of your functions in isolation. What does summation do?
summation :: (Integral a, Fractional b) => a -> a -> (a -> b) -> b
summation i n e = if (i < n)
  then (e i + summation (i+1) n e)
  else (e i)
Well, for all numbers x from i to n, it sums e x. We can write this a lot more clearly with a list comprehension:
summation :: (Enum n, Num a) => n -> n -> (n -> a) -> a
summation i n e = sum [e x | x <- [i..n]]
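Note that the two versions are only equivalent if i is less than or equal to n in the initial call. For i > n, the recursive version still returns e i, while the list comprehension sums an empty list and returns 0.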
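To make the note about i and n concrete, here is a small sketch that puts the question's recursive version (renamed summationRec here, a hypothetical name used only for the comparison) next to the list-comprehension version with the function argument included in its type.

```haskell
-- The recursive summation from the question, renamed for comparison.
summationRec :: (Integral a, Fractional b) => a -> a -> (a -> b) -> b
summationRec i n e = if i < n then e i + summationRec (i + 1) n e else e i

-- The list-comprehension version; note the (n -> a) argument in the type.
summation :: (Enum n, Num a) => n -> n -> (n -> a) -> a
summation i n e = sum [e x | x <- [i .. n]]
```

In GHCi, both `summationRec 1 (4 :: Int) fromIntegral` and `summation 1 (4 :: Int) fromIntegral` give 10.0 as a Double, whereas with the bounds swapped (`5` and `1`) the first returns 5.0 and the second 0.0.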
Next, what does mean do?
mean :: (Real a, Fractional b) => [a] -> b
mean x = (1 / (genericLength x)) *
  (summation 1 (length x) (\i -> realToFrac (x !! (i-1))))
Uh, that looks complicated. First, let us rewrite it as a proper fraction:
mean x = num / denom
  where
    num   = summation 1 (length x) (\i -> realToFrac (x !! (i-1)))
    denom = genericLength x
Much easier to read. So, what do you actually do in num? You calculate the sum of all numbers in x. However, sum already does exactly that. We end up with
mean :: Fractional a => [a] -> a
mean xs = sum xs / genericLength xs
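The reason for genericLength rather than length is the types: length returns an Int, while (/) needs both operands to be the same Fractional type, so `sum xs / length xs` would not type-check. A minimal sketch of the two options (mean' is a hypothetical name for the alternative):

```haskell
import Data.List (genericLength)

-- genericLength returns any Num type, so it can be divided into directly.
mean :: Fractional a => [a] -> a
mean xs = sum xs / genericLength xs

-- Equivalent alternative: take the Int length and convert it explicitly.
mean' :: Fractional a => [a] -> a
mean' xs = sum xs / fromIntegral (length xs)
```

Both give 2.5 for `[1, 2, 3, 4 :: Double]`; which one to use is mostly a matter of taste.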
Note that lists are usually named with an s suffix: one x, multiple xs.
We use the same approach on covariance. First, we rewrite it in simpler terms:
covariance x y = num / denom
  where
    num   = summation 1 (length x) elemwise
    denom = genericLength x
    elemwise i = (realToFrac (x !! (i-1)) - mean x) * (realToFrac (y !! (i-1)) - mean x)
And now we immediately spot an error in elemwise. You wrote y !! (i-1) - mean x, but you meant y !! (i-1) - mean y.
However, let us have a look at the mathematical definition again:
$$ cov(X,Y) = E\left[(X-E[X])*(Y-E[Y])\right] $$
Alright. Let's write this down exactly as it stands:
covariance x y = mean productXY
  where
    productXY = pairwiseProduct xSubMean ySubMean
    xSubMean  = subtractMean x
    ySubMean  = subtractMean y
This is obviously not finished yet. What are pairwiseProduct and subtractMean?
subtractMean :: Fractional a => [a] -> [a]
subtractMean xs = [x - mx | x <- xs]
  where
    mx = mean xs

pairwiseProduct :: Num a => [a] -> [a] -> [a]
pairwiseProduct xs ys = zipWith (*) xs ys
Now we're done. Note how covariance almost reads like pseudo-code? It's easy for a human to read. That's the most important point you should take from this review: make your code easy to read for yourself.
stddev xs is just sqrt (covariance xs xs), so you should probably use that:
-- sqrt needs Floating
stddev :: Floating a => [a] -> a
stddev xs = sqrt (covariance xs xs)
Everything at once
After we've finished our rewrite, it turns out you don't need summation at all. So what do we end up with?
import Data.List (genericLength)

mean :: Fractional a => [a] -> a
mean xs = sum xs / genericLength xs

covariance :: Fractional a => [a] -> [a] -> a
covariance xs ys = mean productXY
  where
    productXY = zipWith (*) [x - mx | x <- xs] [y - my | y <- ys]
    mx        = mean xs
    my        = mean ys

stddev :: Floating a => [a] -> a
stddev xs = sqrt (covariance xs xs)

pearson :: Floating a => [a] -> [a] -> a
pearson xs ys = covariance xs ys / (stddev xs * stddev ys)
Note that covariance became a lot easier to read thanks to the where bindings.
Please keep in mind that this implementation has some performance-related issues. In particular, mean is a poster child for a function that leaks memory: sum and genericLength each traverse xs separately, so the whole list has to stay in memory until both are done. For small lists you should be fine.
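One common way around this, sketched here as an assumption rather than as part of the review, is to compute the sum and the count in a single strict pass. meanStrict is a hypothetical name, and foldl' comes from Data.List.

```haskell
import Data.List (foldl')

-- Single pass: accumulate the running sum and the count together, so the
-- list is consumed once and already-visited cells can be garbage-collected.
meanStrict :: Fractional a => [a] -> a
meanStrict xs = s / fromIntegral n
  where
    (s, n) = foldl' step (0, 0 :: Int) xs
    -- foldl' only forces the pair to weak head normal form, so the `seq`s
    -- force both components too; otherwise thunks would pile up inside it.
    step (acc, cnt) x =
      let acc' = acc + x
          cnt' = cnt + 1
      in acc' `seq` cnt' `seq` (acc', cnt')
```

Like the original, it returns NaN for an empty list of Doubles (0 / 0). The same accumulator idea extends to covariance by also carrying the sums of y and x * y, at the cost of a less "pseudo-code" look.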
Working with lists
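When you work with lists, avoid element-wise access with !!. It's really slow: xs !! i has to walk i cells from the head of the list, so indexing every element in turn takes quadratic time overall. Instead, you want to transform the whole list into a single value (a fold, e.g. length or sum) or into another list (as above with list comprehensions, e.g. [x - mx | x <- xs]).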
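As an illustration of the difference, here is a sketch of a dot product written both ways. The names dotIndexed and dotZipped are hypothetical; the first mirrors the index-based style of the question, the second the whole-list style recommended above.

```haskell
-- Index-based: each `xs !! i` walks from the head of the list again,
-- so this takes O(n^2) steps for lists of length n.
dotIndexed :: Num a => [a] -> [a] -> a
dotIndexed xs ys = sum [xs !! i * ys !! i | i <- [0 .. length xs - 1]]

-- Whole-list: zipWith walks both lists once and sum folds the result, O(n).
dotZipped :: Num a => [a] -> [a] -> a
dotZipped xs ys = sum (zipWith (*) xs ys)
```

Both give 32 for `[1, 2, 3]` and `[4, 5, 6]`. They differ on lists of unequal length: dotIndexed fails with an index error if ys is shorter, while zipWith simply stops at the end of the shorter list.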
import Data.List

mean :: (Floating a) => [a] -> a
mean x = sum / genericLength x
  where sum = foldl (+) 0 x

covariance :: (Floating a) => [a] -> [a] -> a
covariance x y = (mean xy) - (mean x) * (mean y)
  where xy = zipWith (*) x y

pearson :: (Floating a) => [a] -> [a] -> a
pearson x y = (covariance x y) / (stddev x * stddev y)
  where stddev z = (covariance z z)**0.5
Learn to use Folds, Maps, Filters and Zips!
They are key concepts of functional programming.
The sum over a list x can be written as foldl (+) 0 x, and a product as foldl (*) 1 x. This is not limited to basic math operations and numbers: you can supply any function on any type of elements.
In a similar manner, zipWith (*) x y joins two lists into a list of products.
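A short sketch of folds, maps, filters and zips side by side. All function names here are hypothetical and chosen only for illustration; longestWord shows a fold over a non-numeric type.

```haskell
-- Folds collapse a list into a single value.
total :: [Double] -> Double
total = foldl (+) 0

productOf :: [Double] -> Double
productOf = foldl (*) 1

-- Any function and element type works: keep the longest word seen so far
-- (on ties, the earlier word is kept).
longestWord :: [String] -> String
longestWord = foldl (\best w -> if length w > length best then w else best) ""

-- zipWith combines two lists element by element.
pairwiseProducts :: [Double] -> [Double] -> [Double]
pairwiseProducts = zipWith (*)

-- filter drops elements, map transforms the rest.
squaresOfPositives :: [Double] -> [Double]
squaresOfPositives = map (^ 2) . filter (> 0)
```

For example, `longestWord ["a", "abc", "ab"]` is `"abc"` and `squaresOfPositives [-1, 2, 3]` is `[4.0, 9.0]`.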
Don't repeat yourself.
Your functions mean, covariance and stddev share a lot of code. Never copy-paste code; factor out the shared part and reuse it. Here, covariance can be reformulated to use only calls to mean, and stddev is just the self-covariance (and a square root).
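The reformulation of covariance in terms of mean follows from expanding the product inside the expectation. Writing $\mu_X = E[X]$ and $\mu_Y = E[Y]$ (shorthand introduced here):

$$ \operatorname{cov}(X,Y) = E\left[(X-\mu_X)(Y-\mu_Y)\right] = E[XY] - \mu_Y E[X] - \mu_X E[Y] + \mu_X \mu_Y = E[XY] - E[X]\,E[Y] $$

Setting $Y = X$ gives the self-covariance, so $\sigma_X = \sqrt{\operatorname{cov}(X,X)} = \sqrt{E[X^2] - E[X]^2}$ and $\rho_{X,Y} = \operatorname{cov}(X,Y) / (\sigma_X \sigma_Y)$. Since mean divides by the list length, these identities also hold for the averages the code computes, up to floating-point rounding.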
Comment by Zeta (Feb 18, 2017 at 11:00): Even better, use the proper functions already defined in the Prelude, namely sum and product. Don't rewrite those yourself. By the way, your (and the OP's) mean implementation is the poster child of a memory leak. And covariance has to traverse each list x and y three times: once for sum, once for length and once for zipWith. Writing covariance in a memory-efficient way isn't as "don't repeat yourself", but your implementation should be fine for small sample sizes.
Comment by Zeta (Feb 18, 2017 at 11:03): Also, **(0.5) is called sqrt, also in the Prelude.
Comment by Vogel612 (Feb 18, 2017 at 12:07): I like the end result you provide here. What's IMO missing is showing the OP where the problems in the original code are. Your review is "factually correct", but doesn't respect the OP's level of competence in the language. So while the code is good, the review IMO isn't helpful to the OP...