
The following code is from a university assignment of mine: writing a classification algorithm (nearest neighbour) that decides whether a given feature set (each feature is the frequency of a word in an email) is spam or not.

We are given a CSV file (the training data) with the frequencies, and an integer (1 or 0) at the end of each row indicating whether or not that email is spam. So each row has this form:

X1,X2,...,Xn,SPAM

The test data is also in this form (including the SPAM column, so we can verify the accuracy).
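To make the row layout concrete, here is a minimal sketch of turning one such line into a (features, label) pair. It is hand-rolled for illustration only: it assumes plain numeric fields with no quoting and no header line, and splitCommas, parseRow and the sample rows in the check are hypothetical. The code below uses Text.CSV for the actual parsing.

```haskell
data SpamType = Spam | NotSpam deriving (Show, Eq)

-- Hypothetical helper: split one line on commas (assumes no quoted fields).
splitCommas :: String -> [String]
splitCommas s = case break (== ',') s of
  (field, [])       -> [field]
  (field, _ : rest) -> field : splitCommas rest

toSpam :: Int -> SpamType
toSpam 1 = Spam
toSpam _ = NotSpam

-- All columns but the last are word frequencies; the last is the 1/0 spam flag.
parseRow :: String -> ([Float], SpamType)
parseRow line = (map read (init fields), toSpam (truncate (read (last fields) :: Float)))
  where
    fields = splitCommas line
```

For example, parseRow "0,0.5,2.25,1" gives ([0.0,0.5,2.25],Spam).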

My question is: how can I make this code more idiomatic, and what speedups can I make? For example, is there a way to write getMostCommon without having to groupBy and then run maximumBy again?

import Text.CSV
import Data.List
data SpamType = Spam | NotSpam
 deriving (Show, Eq)
type FeatureSet = [Float]
toSpam :: Int -> SpamType
toSpam 0 = NotSpam
toSpam 1 = Spam 
toSpam a = NotSpam
parseClassifiedRecord :: Record -> (FeatureSet, SpamType)
parseClassifiedRecord x = (init converted, toSpam (truncate (last converted)))
 where 
 converted = map (\val -> read val :: Float) x
-- returns the euclidean distance between two feature sets
difference :: FeatureSet -> FeatureSet -> Float
difference first second = sqrt (sum (zipWith (\x y -> (x - y)^2) first second))
-- finds the SpamType of the k nearest 'nodes' in the training set
findKNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> Int -> [SpamType]
findKNearest trainingSet toMatch k = take k (map snd (sortBy (\x y -> (compare (fst x) (fst y))) [(difference (fst x) toMatch, snd x) | x <- trainingSet]))
-- returns item which occurs most often in the list
getMostCommon :: (Eq a) => [a] -> a
getMostCommon list = head (maximumBy (\x y -> (compare (length x) (length y))) (groupBy (\x y -> (x == y)) list))
-- given a feature set, returns an ordered (i.e. same order as input)
-- list of whether or not feature is spam or not spam
-- looks at the closest k neighbours
classify :: [(FeatureSet, SpamType)] -> FeatureSet -> Int -> SpamType
classify trainingSet toClassify k = getMostCommon (findKNearest trainingSet toClassify k)
-- gives a value for the accuracy of expected vs actual for spam classification
-- i.e. num classified correctly / num total
accuracy :: [(SpamType, SpamType)] -> Float
accuracy classifications = (fromIntegral $ length (filter (\x -> (fst x) == (snd x)) classifications)) / (fromIntegral $ length classifications)
main = do 
 packed <- parseCSVFromFile "spam-train.csv"
 packedtest <- parseCSVFromFile "spam-test.csv"
 let (Right trainingSet) = packed
 let (Right testSet) = packedtest
 let classifiedTrainingSet = map parseClassifiedRecord (tail (init trainingSet))
 let unclassified = map parseClassifiedRecord (tail (init testSet))
 let classified = map (\x -> (classify classifiedTrainingSet (fst x) 1, snd x)) unclassified
 putStrLn (show (accuracy classified))
Jamal
asked Sep 19, 2012 at 4:05
  • Small suggestion: don't use sqrt. Comparing squared distances won't change the answer. Commented Oct 8, 2012 at 0:27
  • Using unboxed vectors should be even faster. I will look into adding a FromRecord instance for unboxed vectors to cassava. Commented Oct 10, 2012 at 17:34
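The first comment works because sqrt is monotonically increasing on non-negative numbers: if d1 < d2 then sqrt d1 <= sqrt d2, so ranking training rows by squared distance gives the same neighbour order as ranking by Euclidean distance, while saving one sqrt per row. A minimal sketch using sortOn from base (4.8 or later); squaredDistance and nearestOrder are illustrative names, not taken from the code above.

```haskell
import Data.List (sortOn)

type FeatureSet = [Float]

-- Squared Euclidean distance. Dropping the sqrt keeps the ordering,
-- because sqrt is monotonically increasing on non-negative values.
squaredDistance :: FeatureSet -> FeatureSet -> Float
squaredDistance a b = sum (zipWith (\x y -> (x - y) ^ (2 :: Int)) a b)

-- Indices of the training rows, nearest first.
nearestOrder :: [FeatureSet] -> FeatureSet -> [Int]
nearestOrder rows query =
  map fst (sortOn (squaredDistance query . snd) (zip [0 ..] rows))
```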

3 Answers

  • As previously noted, the comparing function (Data.Ord.comparing) is very useful for building comparison functions, especially when you are simply comparing by some field, since you just pass the accessor to comparing.
  • It's a good idea to display the error message if there is one; for the Either type, use either to handle both the Left and Right cases.

  • There is no need for long variable names when they have a small scope and no natural succinct name; e.g. in the difference function, naming the parameters first and second just clutters the definition.

  • It is common practice in Haskell to use function composition (.) and application ($) rather than parentheses wherever applicable; it is usually more readable.

  • putStrLn . show == print

  • groupBy (==) == group

  • There are quite a lot of combinators for working with functions over tuples; unfortunately (or fortunately, depending on how you look at it) most of them are abstracted over Control.Arrow. I haven't used any of them in the code below, but (\x -> (difference toMatch (fst x), snd x)) could be written as first (difference toMatch) using Control.Arrow.first.

  • Rather than having findKNearest return only the first k elements, have it return all of them and let the caller consume as many as it needs; this saves a parameter and makes the function more reusable.
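The last two suggestions combine naturally. For plain functions, Control.Arrow.first f applies f to the first component of a pair, which replaces the lambda, and a findNearest that returns every label (nearest first) lets the caller take as many as it needs. A sketch assuming the question's SpamType and FeatureSet; the tiny training set in the check is made up for illustration.

```haskell
import Control.Arrow (first)
import Data.List (sortBy)
import Data.Ord (comparing)

data SpamType = Spam | NotSpam deriving (Show, Eq)
type FeatureSet = [Float]

difference :: FeatureSet -> FeatureSet -> Float
difference a b = sqrt . sum $ zipWith (\x y -> (x - y) ^ (2 :: Int)) a b

-- Every training label, nearest first; the caller decides how many to take.
findNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> [SpamType]
findNearest trainingSet toMatch =
  map snd
    . sortBy (comparing fst)
    . map (first (difference toMatch)) -- instead of \(a, b) -> (difference toMatch a, b)
    $ trainingSet
```

A k-nearest lookup is then just take k (findNearest trainingSet toMatch).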

Note: I left out comments only to make the changes more visible.

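import Text.CSV (Record, parseCSVFromFile)
import Data.List (sortBy, maximumBy, group)
import Data.Ord (comparing)
import Control.Monad (liftM)
-- SpamType, FeatureSet and toSpam are as in the question
parseClassifiedRecord :: Record -> (FeatureSet, SpamType)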
parseClassifiedRecord x = (init converted, toSpam . truncate . last $ converted)
 where converted = map read x
difference a b = sqrt . sum . zipWith (\x y -> (x - y)^2) a $ b
findNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> [SpamType]
findNearest trainingSet toMatch =
 map snd
 . sortBy (comparing fst)
 . map (\(a,b) -> (difference a toMatch, b))
 $ trainingSet
getMostCommon = head . maximumBy (comparing length) . group
classify :: Int -> [(FeatureSet, SpamType)] -> FeatureSet -> SpamType
classify k trainingSet = getMostCommon . take k . findNearest trainingSet
accuracy :: [(SpamType, SpamType)] -> Float
accuracy cs = 
 (fromIntegral 
 . length 
 . filter (uncurry (==)) 
 $ cs) / (fromIntegral . length $ cs)
parseFile = liftM (either (error . show) (tail . init)) . parseCSVFromFile
main = do 
 trainingSet <- parseFile "spam-train.csv"
 testSet <- parseFile "spam-test.csv"
 let classifiedTrainingSet = map parseClassifiedRecord trainingSet
 let unclassified = map parseClassifiedRecord testSet
 let classified = map
       (\(x,y) -> (classify 1 classifiedTrainingSet x, y))
       unclassified
 print (accuracy classified)
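On base 4.8 or later, Data.List.sortOn does the decorate, sort and undecorate steps of findNearest for you: it evaluates the key function once per element and sorts on the cached result, whereas sortBy (comparing f) would recompute f on every comparison. A sketch of findNearest written that way, assuming the question's SpamType and FeatureSet:

```haskell
import Data.List (sortOn)

data SpamType = Spam | NotSpam deriving (Show, Eq)
type FeatureSet = [Float]

difference :: FeatureSet -> FeatureSet -> Float
difference a b = sqrt . sum $ zipWith (\x y -> (x - y) ^ (2 :: Int)) a b

-- sortOn computes each row's distance once, then sorts by it;
-- only the labels are kept in the result, nearest first.
findNearest :: [(FeatureSet, SpamType)] -> FeatureSet -> [SpamType]
findNearest trainingSet toMatch =
  map snd (sortOn (difference toMatch . fst) trainingSet)
```

classify k stays getMostCommon . take k . findNearest trainingSet.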
answered Oct 7, 2012 at 23:12

One helpful function is Data.Ord.comparing:

comparing :: (Ord a) => (b -> a) -> b -> b -> Ordering
comparing p x y = compare (p x) (p y)

With this and using function composition you can write:

getMostCommon :: (Eq a) => [a] -> a
getMostCommon = head . maximumBy (comparing length) . groupBy (==)

Similar changes apply to several places. Another small simplification could be changing (\x -> (fst x) == (snd x)) to uncurry (==).
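One caveat applies to getMostCommon as written in the question and in both answers' rewrites: group and groupBy only merge adjacent equal elements. On [Spam, Spam, NotSpam, NotSpam, Spam] they produce [[Spam,Spam],[NotSpam,NotSpam],[Spam]], and maximumBy can pick NotSpam even though Spam is the majority. With k = 1, as in main, this never matters, but for larger k the labels need sorting first or a single counting pass. The sketch below shows both, assuming SpamType also derives Ord; the counting version uses Data.Map.Strict from the containers package (shipped with GHC) and avoids a separate group step.

```haskell
import Data.List (group, maximumBy, sort)
import qualified Data.Map.Strict as Map
import Data.Ord (comparing)

-- Assumes SpamType derives Ord as well, so labels can be sorted or used as Map keys.
data SpamType = Spam | NotSpam deriving (Show, Eq, Ord)

-- One counting pass with a Map, then pick the label with the highest count.
-- On a tie, maximumBy keeps the last maximum, i.e. the label that sorts last.
getMostCommon :: Ord a => [a] -> a
getMostCommon =
  fst . maximumBy (comparing snd) . Map.toList . Map.fromListWith (+) . map (\x -> (x, 1 :: Int))

-- The group-based version is also correct once sorting makes equal labels adjacent.
getMostCommonSorted :: Ord a => [a] -> a
getMostCommonSorted = head . maximumBy (comparing length) . group . sort
```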

answered Sep 19, 2012 at 17:45

I liked HaskellElephant's answer, but since you mentioned performance, I put together a version using the excellent cassava library together with vector. It's a little less elegant, but it ran about 30x faster in my tests.

module CSVTest.New where
import Data.Csv
import Data.Vector
import Data.ByteString.Lazy (ByteString, readFile)
import Prelude hiding (tail, readFile, filter, take, init, 
 last, sum, zipWith, map, head, length,
 foldl)
import Data.Vector.Algorithms.Heap
import Data.Ord
import Control.Monad
import Control.Monad.ST
data SpamType = Spam | NotSpam
 deriving (Show, Eq, Enum, Bounded)
type FeatureSet = Vector Float 
toSpam :: Int -> SpamType
toSpam 0 = NotSpam
toSpam 1 = Spam 
toSpam a = NotSpam
parseClassifiedRecord :: Vector Float -> (FeatureSet, SpamType)
parseClassifiedRecord x = (init x, toSpam . truncate . last $ x)
difference :: FeatureSet -> FeatureSet -> Float
difference a b = sum . zipWith (\x y -> (x - y)^2) a $ b
findNearest :: Vector (FeatureSet, SpamType) -> FeatureSet -> Vector SpamType
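findNearest trainingSet toMatch = result where
 v = map (\x -> (difference (fst x) toMatch, snd x)) trainingSet
 -- sort a mutable copy by distance, then freeze it again
 v' = runST $ do
   mv <- thaw v
   sortBy (comparing fst) mv
   freeze mv
 result = map snd v' -- labels come from the sorted copy v', not the unsorted v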
getMostCommon :: Vector SpamType -> SpamType
getMostCommon v = result where
 (spamCount, notSpamCount) = foldl (\(sc, nsc) x -> if x == Spam
                                        then (sc + 1, nsc)
                                        else (sc, nsc + 1)) (0, 0) v
 -- ties go to Spam
 result = if spamCount >= notSpamCount
            then Spam
            else NotSpam
classify :: Int -> Vector (FeatureSet, SpamType) -> FeatureSet -> SpamType
classify k trainingSet = getMostCommon . take k . findNearest trainingSet
accuracy :: Vector (SpamType, SpamType) -> Float
accuracy cs = 
 (fromIntegral . length . filter (uncurry (==)) $ cs)
 / (fromIntegral . length $ cs)
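-- newer cassava versions: decode takes a HasHeader flag, and HasHeader skips the header row
parseCSVFromFile :: FilePath -> IO (Either String (Vector FeatureSet))
parseCSVFromFile = fmap (decode HasHeader) . readFile
-- the header is already skipped and cassava yields no empty trailing record,
-- so the tail . init needed with Text.CSV is dropped here
parseFile = liftM (either error id) . parseCSVFromFile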
run trainingFile testFile = do 
 trainingSet <- parseFile trainingFile
 testSet <- parseFile testFile
 let classifiedTrainingSet = map parseClassifiedRecord trainingSet
 let unclassified = map parseClassifiedRecord testSet
 let classified = map (\x -> (classify 1 classifiedTrainingSet (fst x), snd x))
                      unclassified
 print (accuracy classified) 
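Sorting the whole training set just to read off the first k labels does more work than needed; with k = 1, as used here, the nearest neighbour is simply a minimumBy over the distances. For general k, one option is to keep a bounded, sorted list of the k best candidates while scanning once, which costs O(n·k) instead of a full O(n log n) sort. A list-based sketch of that idea; kSmallest and kNearestLabels are hypothetical helpers. On vectors, Data.Vector.Algorithms.Heap also offers partialSortBy for the same purpose.

```haskell
import Data.List (foldl', insertBy)
import Data.Ord (comparing)

data SpamType = Spam | NotSpam deriving (Show, Eq)

-- Hypothetical helper: the k pairs with the smallest first component, ascending.
-- Each step inserts into a sorted list of at most k candidates and trims it back to k.
kSmallest :: Ord d => Int -> [(d, a)] -> [(d, a)]
kSmallest k = foldl' step []
  where
    step best x = take k (insertBy (comparing fst) x best)

-- Labels of the k nearest rows, given precomputed (distance, label) pairs.
kNearestLabels :: Int -> [(Float, SpamType)] -> [SpamType]
kNearestLabels k = map snd . kSmallest k
```

With k = 1 this keeps a single candidate, which is equivalent to taking the label of minimumBy (comparing fst).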
answered Oct 10, 2012 at 3:43