The following code is from a university assignment of mine to write a classification algorithm (using nearest neighbour) to classify whether or not a given feature set (each feature is the frequency of words in an email) is spam or not.
We are given a CSV file (the training data) with the frequencies, and an integer (1 or 0) at the end of the row indicating whether or not it is spam. So in this form:
X1,X2,...,Xn,SPAM
The test data is also in this form (including the SPAM column, so we can verify the accuracy).
My question is: how can I make this code more idiomatic, and what speedups can I make? For example, is there a way to write getMostCommon without having to groupBy and then run a maximumBy again?
import Text.CSV
import Data.List
data SpamType = Spam | NotSpam
    deriving (Show, Eq)
type FeatureSet = [Float]
toSpam :: Int -> SpamType
toSpam 0 = NotSpam
toSpam 1 = Spam
toSpam a = NotSpam
parseClassifiedRecord :: Record -> (FeatureSet, SpamType)
parseClassifiedRecord x = (init converted, toSpam (truncate (last converted)))
    where
        converted = map (\val -> read val :: Float) x
-- returns the euclidean distance between two feature sets
difference :: FeatureSet -> FeatureSet -> Float
difference first second = sqrt (sum (zipWith (\x y -> (x - y)^2) first second))
-- finds the SpamType of the k nearest 'nodes' in the training set
findKNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> Int -> [SpamType]
findKNearest trainingSet toMatch k = take k (map snd (sortBy (\x y -> (compare (fst x) (fst y))) [(difference (fst x) toMatch, snd x) | x <- trainingSet]))
-- returns item which occurs most often in the list
getMostCommon :: (Eq a) => [a] -> a
getMostCommon list = head (maximumBy (\x y -> (compare (length x) (length y))) (groupBy (\x y -> (x == y)) list))
-- given a feature set, returns an ordered (i.e. same order as input)
-- list of whether or not feature is spam or not spam
-- looks at the closest k neighbours
classify :: [(FeatureSet, SpamType)] -> FeatureSet -> Int -> SpamType
classify trainingSet toClassify k = getMostCommon (findKNearest trainingSet toClassify k)
-- gives a value for the accuracy of expected vs actual for spam classification
-- i.e. num classified correctly / num total
accuracy :: [(SpamType, SpamType)] -> Float
accuracy classifications = (fromIntegral $ length (filter (\x -> (fst x) == (snd x)) classifications)) / (fromIntegral $ length classifications)
main = do
    packed <- parseCSVFromFile "spam-train.csv"
    packedtest <- parseCSVFromFile "spam-test.csv"
    let (Right trainingSet) = packed
    let (Right testSet) = packedtest
    let classifiedTrainingSet = map parseClassifiedRecord (tail (init trainingSet))
    let unclassified = map parseClassifiedRecord (tail (init testSet))
    let classified = map (\x -> (classify classifiedTrainingSet (fst x) 1, snd x)) unclassified
    putStrLn (show (accuracy classified))
3 Answers
- As previously noted, the comparing function is very useful for making compare functions, especially if you are simply comparing by some field, since you just have to give the accessor to comparing.
- It's a good idea to display the error message if there is one, and for the Either type use either to handle both the Left and Right cases.
- There is no need for long variable names if they have a small scope and the variables don't have a nice succinct name. For example, in the difference function, using first and second as parameter names just clutters the definition.
- It is common practice in Haskell to use function composition (.) and application ($) rather than parentheses wherever applicable. It is simply more readable.
- putStrLn . show == print
- groupBy (==) == group
- There are quite a lot of combinators for working with functions over tuples. Unfortunately (or fortunately, depending on how you look at it) most of them are abstracted over Control.Arrow. I haven't included any of that in the code below, but (\x -> (difference toMatch (fst x), snd x)) could be written as first (difference toMatch) using Control.Arrow.first.
- Rather than limiting findKNearest to the first k elements of the list, make the function return all the elements and just consume as many as you need. This saves you a parameter and makes the function more reusable.
Note: I left out comments only to make the changes more visible.
-- needed in addition to the imports in the question
import Data.Ord (comparing)
import Control.Monad (liftM)
parseClassifiedRecord :: Record -> (FeatureSet, SpamType)
parseClassifiedRecord x = (init converted, toSpam . truncate . last $ converted)
    where converted = map read x
difference a b = sqrt . sum . zipWith (\x y -> (x - y)^2) a $ b
findNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> [SpamType]
findNearest trainingSet toMatch =
    map snd
    . sortBy (comparing fst)
    . map (\(a,b) -> (difference a toMatch, b))
    $ trainingSet
getMostCommon = head . maximumBy (comparing length) . group
classify :: Int -> [(FeatureSet, SpamType)] -> FeatureSet -> SpamType
classify k trainingSet = getMostCommon . take k . findNearest trainingSet
accuracy :: [(SpamType, SpamType)] -> Float
accuracy cs =
    (fromIntegral
    . length
    . filter (uncurry (==))
    $ cs) / (fromIntegral . length $ cs)
parseFile = liftM (either (error . show) (tail . init)) . parseCSVFromFile
main = do
    trainingSet <- parseFile "spam-train.csv"
    testSet <- parseFile "spam-test.csv"
    let classifiedTrainingSet = map parseClassifiedRecord trainingSet
    let unclassified = map parseClassifiedRecord testSet
    let classified = map
            (\(x,y) -> (classify 1 classifiedTrainingSet x, y))
            unclassified
    print (accuracy classified)
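The Control.Arrow remark and the "return everything, consume what you need" advice above can be sketched together as follows. This is only an illustration: nearestLabels is a hypothetical name, and the label type is left polymorphic so the snippet compiles on its own without the SpamType declaration.

```haskell
import Control.Arrow (first)
import Data.List (sortBy)
import Data.Ord (comparing)

type FeatureSet = [Float]

-- Euclidean distance between two feature sets.
difference :: FeatureSet -> FeatureSet -> Float
difference a b = sqrt . sum $ zipWith (\x y -> (x - y) ^ 2) a b

-- Labels of the whole training set, ordered from nearest to farthest.
-- 'first f' applies f to the first component of a pair and leaves the
-- second component untouched.
nearestLabels :: FeatureSet -> [(FeatureSet, label)] -> [label]
nearestLabels toMatch =
    map snd . sortBy (comparing fst) . map (first (difference toMatch))
```

A caller that only wants k neighbours can then write take k (nearestLabels toMatch trainingSet), so the function itself needs no k parameter.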
One helpful function is Data.Ord.comparing:
comparing :: (Ord a) => (b -> a) -> b -> b -> Ordering
comparing p x y = compare (p x) (p y)
With this and using function composition you can write:
getMostCommon :: (Eq a) => [a] -> a
getMostCommon = head . maximumBy (comparing length) . groupBy (==)
Similar changes apply in several places. Another small simplification could be changing (\x -> (fst x) == (snd x)) to uncurry (==).
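One caveat about the groupBy-based getMostCommon: groupBy (==) only groups adjacent equal elements, so a list such as [Spam, NotSpam, Spam] yields three groups of length one. For k = 1 this makes no difference, but for larger k a counting approach may be safer. The following is a minimal sketch that needs only an Eq constraint; counts and mostCommon are hypothetical names, not part of the original code, and mostCommon is partial on the empty list, just like the original.

```haskell
import Data.List (foldl', maximumBy)
import Data.Ord (comparing)

-- Count occurrences using an association list, so only Eq is required.
-- This is quadratic in the number of distinct values, which is fine for
-- a two-valued type such as SpamType.
counts :: Eq a => [a] -> [(a, Int)]
counts = foldl' bump []
  where
    bump [] x = [(x, 1)]
    bump ((y, n) : rest) x
        | x == y    = (y, n + 1) : rest
        | otherwise = (y, n) : bump rest x

-- The value with the highest count; fails on an empty list.
mostCommon :: Eq a => [a] -> a
mostCommon = fst . maximumBy (comparing snd) . counts
```

If the element type also has an Ord instance, sorting before grouping (group . sort) or counting in a Data.Map would be alternatives.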
I liked HaskellElephant's answer, but since you mentioned performance, I put together a version using the killer cassava library + vector. It's a little less elegant, but runs 30x faster in my tests.
module CSVTest.New where
import Data.Csv
import Data.Vector
import Data.ByteString.Lazy (ByteString, readFile)
import Prelude hiding (tail, readFile, filter, take, init,
                       last, sum, zipWith, map, head, length,
                       foldl)
import Data.Vector.Algorithms.Heap
import Data.Ord
import Control.Monad
import Control.Monad.ST
data SpamType = Spam | NotSpam
    deriving (Show, Eq, Enum, Bounded)
type FeatureSet = Vector Float
toSpam :: Int -> SpamType
toSpam 0 = NotSpam
toSpam 1 = Spam
toSpam a = NotSpam
parseClassifiedRecord :: Vector Float -> (FeatureSet, SpamType)
parseClassifiedRecord x = (init x, toSpam . truncate . last $ x)
-- squared Euclidean distance; the sqrt is omitted because it does not
-- change the ordering of the neighbours
difference :: FeatureSet -> FeatureSet -> Float
difference a b = sum . zipWith (\x y -> (x - y)^2) a $ b
findNearest :: Vector (FeatureSet, SpamType) -> FeatureSet -> Vector SpamType
findNearest trainingSet toMatch = result where
    v = map (\x -> (difference (fst x) toMatch, snd x)) trainingSet
    -- sort a mutable copy in place, then freeze it again
    v' = runST $ do
        mv <- thaw v
        sortBy (comparing fst) mv
        freeze mv
    -- use the sorted vector v', not the unsorted v
    result = map snd v'
getMostCommon :: Vector SpamType -> SpamType
getMostCommon v = result where
    (spamCount, notSpamCount) = foldl (\(sc, nsc) x -> if x == Spam
                                          then (sc + 1, nsc)
                                          else (sc, nsc + 1)) (0,0) v
    result = if spamCount >= notSpamCount
                then Spam
                else NotSpam
classify :: Int -> Vector (FeatureSet, SpamType) -> FeatureSet -> SpamType
classify k trainingSet = getMostCommon . take k . findNearest trainingSet
accuracy :: Vector (SpamType, SpamType) -> Float
accuracy cs =
    (fromIntegral . length . filter (uncurry (==)) $ cs)
    / (fromIntegral . length $ cs)
-- note: newer cassava versions take an extra HasHeader argument,
-- e.g. decode NoHeader
parseCSVFromFile :: FilePath -> IO (Either String (Vector FeatureSet))
parseCSVFromFile = fmap decode . readFile
parseFile = liftM (either (error . show) (tail . init)) . parseCSVFromFile
run trainingFile testFile = do
    trainingSet <- parseFile trainingFile
    testSet <- parseFile testFile
    let classifiedTrainingSet = map parseClassifiedRecord trainingSet
    let unclassified = map parseClassifiedRecord testSet
    let classified = map (\x -> (classify 1 classifiedTrainingSet (fst x), snd x))
                         unclassified
    print (accuracy classified)
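Since every version above ends up calling classify with k = 1, one further possible speedup is to skip sorting altogether for that case and take the single minimum in one pass. The sketch below uses plain lists and only base so that it stands alone; classifyNearest and squaredDistance are hypothetical names, and the function is partial on an empty training set.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)

type FeatureSet = [Float]

-- Squared Euclidean distance; sqrt is monotonic, so it can be skipped
-- when the distances are only compared against each other.
squaredDistance :: FeatureSet -> FeatureSet -> Float
squaredDistance a b = sum (zipWith (\x y -> (x - y) ^ 2) a b)

-- Label of the single closest training example: one linear pass
-- instead of an O(n log n) sort. Fails on an empty training set.
classifyNearest :: [(FeatureSet, label)] -> FeatureSet -> label
classifyNearest trainingSet toMatch =
    snd (minimumBy (comparing (squaredDistance toMatch . fst)) trainingSet)
```

For k greater than 1 this shortcut no longer applies, and a full or partial sort is still needed.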