To get myself into functional programming, I implemented a simple nearest neighbor classifier in Haskell. The code works but is incredibly slow. Profiling tells me that most of the time is spent calculating diffRed, diffGreen and diffBlue. That seems reasonable since there is not much else to do. But why does it take so long, and how can I improve it? A quick test with NumPy showed much better performance (approx. 10 times faster).
import System.Directory
import System.IO
import qualified Data.ByteString as ByteStr
import qualified Data.ByteString.Char8 as ByteStrCh8
import Data.Word
import Data.List
import qualified Data.Vector.Unboxed as Vec

data LabeledImage = LabeledImage {
      labelIdx :: Int
    , redPixels :: Vec.Vector Word8
    , greenPixels :: Vec.Vector Word8
    , bluePixels :: Vec.Vector Word8
    } deriving (Eq)

instance Show LabeledImage where
    show (LabeledImage label _ _ _) = "Image of type " ++ show label ++ "."

printEnumeratedLabels :: [String] -> Int -> IO ()
printEnumeratedLabels [] _ = return ()
printEnumeratedLabels (displayedString:trailingStrings) index = do
    putStrLn $ "String " ++ show index ++ ": " ++ displayedString
    printEnumeratedLabels trailingStrings (index + 1)

extractLabeledImages :: ByteStr.ByteString -> [LabeledImage] -> [LabeledImage]
extractLabeledImages source images
    | ByteStr.length source >= imgLength =
        let
            (label,rbgData) = ByteStr.splitAt labelBytes source
            (redData,bgData) = ByteStr.splitAt colorBytes rbgData
            (greenData,bData) = ByteStr.splitAt colorBytes bgData
            (blueData,trailData) = ByteStr.splitAt colorBytes bData
            numLabel = fromIntegral (ByteStr.head label)
            redValues = Vec.generate (ByteStr.length redData) (ByteStr.index redData)
            greenValues = Vec.generate (ByteStr.length greenData) (ByteStr.index greenData)
            blueValues = Vec.generate (ByteStr.length blueData) (ByteStr.index blueData)
        in
            extractLabeledImages trailData (images ++ [LabeledImage numLabel redValues greenValues blueValues])
    | otherwise = images
    where
        labelBytes = 1
        colorBytes = 1024
        imgLength = labelBytes + 3 * colorBytes

calculateL1Distance :: LabeledImage -> LabeledImage -> Int
calculateL1Distance referenceImage testImage =
    let
        substractPixels :: Word8 -> Word8 -> Int
        substractPixels a b = abs $ fromIntegral a - fromIntegral b
        diffRed = Vec.zipWith substractPixels (redPixels referenceImage) (redPixels testImage)
        diffGreen = Vec.zipWith substractPixels (greenPixels referenceImage) (greenPixels testImage)
        diffBlue = Vec.zipWith substractPixels (bluePixels referenceImage) (bluePixels testImage)
    in
        fromIntegral $ Vec.sum diffRed + Vec.sum diffGreen + Vec.sum diffBlue

findMinimalDistanceImage :: (LabeledImage -> LabeledImage -> Int) -> [LabeledImage] -> LabeledImage -> Maybe LabeledImage
findMinimalDistanceImage distance referenceImages testImage =
    let
        distances = [(referenceImage, distance referenceImage testImage) | referenceImage <- referenceImages ]
        absDistances = map snd distances
        minimalDistance = minimum absDistances
        minIndex = elemIndex minimalDistance absDistances
    in
        case minIndex of
            Just index -> Just $ fst (distances !! index)
            Nothing -> Nothing

checkMatch :: Maybe LabeledImage -> LabeledImage -> Maybe Bool
checkMatch Nothing _ = Nothing
checkMatch (Just referenceImage) testImage =
    let
        img = referenceImage
    in
        Just (labelIdx img == labelIdx testImage)

checkTrue :: Maybe Bool -> Bool
checkTrue value
    | value == Just True = True
    | otherwise = False

checkFalse :: Maybe Bool -> Bool
checkFalse value
    | value == Just False = True
    | otherwise = False

checkNothing :: Maybe Bool -> Bool
checkNothing Nothing = True
checkNothing _ = False

main = do
    labelsStr <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\batches.meta.txt"
    let labels = lines $ ByteStrCh8.unpack labelsStr
    printEnumeratedLabels labels 1
    batch1Raw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\data_batch_1.bin"
    let batch1 = extractLabeledImages batch1Raw []
    putStrLn $ "Number of batch 1 images: " ++ show (length batch1)
    batch2Raw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\data_batch_2.bin"
    let batch2 = extractLabeledImages batch2Raw []
    putStrLn $ "Number of batch 2 images: " ++ show (length batch2)
    batch3Raw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\data_batch_3.bin"
    let batch3 = extractLabeledImages batch3Raw []
    putStrLn $ "Number of batch 3 images: " ++ show (length batch3)
    batch4Raw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\data_batch_4.bin"
    let batch4 = extractLabeledImages batch4Raw []
    putStrLn $ "Number of batch 4 images: " ++ show (length batch4)
    batch5Raw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\data_batch_5.bin"
    let batch5 = extractLabeledImages batch5Raw []
    putStrLn $ "Number of batch 5 images: " ++ show (length batch5)
    testBatchRaw <- ByteStr.readFile "M:\\Documents\\StanfordCNN\\cifar10\\test_batch.bin"
    let testBatch = extractLabeledImages testBatchRaw []
    putStrLn $ "Number of test batch images: " ++ show (length testBatch)
    let referenceImages = batch1 ++ batch2 ++ batch3 ++ batch4 ++ batch5
    let testImages = testBatch
    putStrLn "Created image sets. Starting tests."
    let evaluateImage = checkMatch . findMinimalDistanceImage calculateL1Distance referenceImages
    let results = [evaluateImage testImage testImage | testImage <- testImages ]
    putStrLn $ "Results: Match:" ++ show (length (filter checkTrue results))
        ++ " Fail:" ++ show (length (filter checkFalse results))
        ++ " Error:" ++ show (length (filter checkNothing results))

Here are the stats from ghc:
    INIT    time      0.000s  (    0.017s elapsed)
    MUT     time  28954.453s  (29156.816s elapsed)
    GC      time    554.672s  (  627.758s elapsed)
    EXIT    time      0.000s  (    0.133s elapsed)
    Total   time  29509.125s  (29784.724s elapsed)

    %GC     time       1.9%  (2.1% elapsed)

    Alloc rate    3,402,660,550 bytes per MUT second

    Productivity  98.1% of total user, 97.9% of total elapsed
I compiled with

    ghc -O -fforce-recomp -rtsopts -o test .\compare_images.hs

and started the program with

    .\test.exe +RTS -sstderr
The program takes every image from the CIFAR-10 test set and compares it with all reference images using an L1 nearest neighbor classifier. The image is then assigned to one of the 10 classes based on the class of its nearest neighbor. Finally, for each test image, the correct class and the determined class are compared, and the numbers of correct and wrong guesses are counted.
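The classification step described above can be sketched in a few lines. This is a minimal illustration rather than the program itself: it assumes a simplified, hypothetical `Image` type (a label paired with one flat pixel vector instead of the three color channels), and the names `l1Distance`, `classify` and `countMatches` are made up for this sketch.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)
import qualified Data.Vector.Unboxed as Vec

-- Hypothetical simplified image: a class label plus one flat pixel vector.
type Image = (Int, Vec.Vector Int)

-- L1 distance: sum of absolute per-pixel differences.
l1Distance :: Vec.Vector Int -> Vec.Vector Int -> Int
l1Distance a b = Vec.sum (Vec.zipWith (\x y -> abs (x - y)) a b)

-- Label of the reference image closest to the test image,
-- or Nothing when there are no reference images at all.
classify :: [Image] -> Image -> Maybe Int
classify []   _         = Nothing
classify refs (_, test) =
    Just . fst $ minimumBy (comparing snd)
        [ (label, l1Distance pixels test) | (label, pixels) <- refs ]

-- (correct, wrong) guess counts over a test set.
-- A test image without any nearest neighbor counts as wrong here.
countMatches :: [Image] -> [Image] -> (Int, Int)
countMatches refs tests = (length hits, length tests - length hits)
  where
    hits = [ () | t@(label, _) <- tests, classify refs t == Just label ]
```

Using `minimumBy` walks the list of distances once, instead of computing the minimum and then searching for its index as `findMinimalDistanceImage` does. The distance computation itself dominates the runtime, so this is mostly a readability choice.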
- You... left that program running for ~8 hours to get a profile? That's some dedication. Just for completeness, could you add the flags you used? Also, a little bit more detail on what your code does would be great. – Zeta, May 4, 2017 at 8:56
- Well, I noticed that it is slow. I just wanted to know how slow it is. Since it only runs on one core, it did not hinder me in doing other work. As it seems, the program causes only 15% CPU load (according to Task Manager). I added compile flags and runtime parameters to the post. – Oliver Gerlach, May 4, 2017 at 9:47
- @200_success: Why time-limit-exceeded? It's not a programming challenge, and there is no actual time limit; the code is just slow compared to NumPy (and CIFAR-10 test images are being used). Shouldn't it be tagged with performance? – Zeta, May 4, 2017 at 12:58
- @Zeta I suppose it's a matter of opinion, and could go either way. Two criteria that I usually apply are: it's so slow that it's "unacceptable", and the fix is likely to require a new algorithm. Change it back if you prefer. – 200_success, May 4, 2017 at 13:27
1 Answer
As it turns out, a few optimizations drastically improved performance:
- Convert pixel data to Int during loading
- Enable -O2 optimizations (pass -O2 instead of -O to ghc)
- Optimize calculateL1Distance
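The first item might look roughly like the sketch below. It is an illustration under assumptions, not the loader actually used: `toIntVector`, `parseImage` and `parseImages` are hypothetical names, the record layout is copied from the question with `Word8` swapped for `Int`, and building the list with `unfoldr` is a substitution made here to avoid the repeated `images ++ [...]` appends.

```haskell
import Data.List (unfoldr)
import qualified Data.ByteString as ByteStr
import qualified Data.Vector.Unboxed as Vec

-- Same record as in the question, but with Int pixels.
data LabeledImage = LabeledImage
    { labelIdx    :: Int
    , redPixels   :: Vec.Vector Int
    , greenPixels :: Vec.Vector Int
    , bluePixels  :: Vec.Vector Int
    }

-- Widen every byte to Int once, at load time, so the distance
-- loop no longer has to convert each pixel on every comparison.
toIntVector :: ByteStr.ByteString -> Vec.Vector Int
toIntVector bytes =
    Vec.generate (ByteStr.length bytes) (fromIntegral . ByteStr.index bytes)

-- Split one record (1 label byte + 3 * 1024 color bytes) off the input.
-- Returns Nothing when fewer bytes than one full record remain.
parseImage :: ByteStr.ByteString -> Maybe (LabeledImage, ByteStr.ByteString)
parseImage source
    | ByteStr.length source < imgLength = Nothing
    | otherwise = Just (LabeledImage label red green blue, rest)
  where
    colorBytes = 1024
    imgLength  = 1 + 3 * colorBytes
    label      = fromIntegral (ByteStr.head source)
    (redData,   bgData) = ByteStr.splitAt colorBytes (ByteStr.drop 1 source)
    (greenData, bData)  = ByteStr.splitAt colorBytes bgData
    (blueData,  rest)   = ByteStr.splitAt colorBytes bData
    red   = toIntVector redData
    green = toIntVector greenData
    blue  = toIntVector blueData

-- Parse records until the input is exhausted; trailing bytes are ignored.
parseImages :: ByteStr.ByteString -> [LabeledImage]
parseImages = unfoldr parseImage
```

With this layout, `parseImages batch1Raw` would take the place of `extractLabeledImages batch1Raw []`. The trade-off is memory: on a 64-bit GHC an unboxed `Int` vector needs eight bytes per pixel instead of one.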
The optimized code for calculateL1Distance is:
calculateL1Distance :: LabeledImage -> LabeledImage -> Int
calculateL1Distance reference test =
    let
        substractPixels :: Int -> Int -> Int
        substractPixels a b = abs $ a - b
        diff f = Vec.sum $ Vec.zipWith substractPixels (f reference) (f test)
    in
        diff redPixels + diff greenPixels + diff bluePixels
This code is not only more pleasant to read; I assume it also allows for more aggressive optimizations. At least it cuts the runtime down to 5205.797s. This is comparable to NumPy and seems acceptable for this kind of algorithm.
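To make the difference between the two variants concrete, here is a small side-by-side sketch on single channels. The names `l1Word8`, `l1Int` and `wrappedDiff` are hypothetical and only serve this comparison. Both distance functions feed `zipWith` straight into `sum`, which the vector library can typically fuse into one loop when optimizations are enabled; whether that accounts for the measured speedup is not verified here.

```haskell
import Data.Word (Word8)
import qualified Data.Vector.Unboxed as Vec

-- Question's approach: widen each Word8 to Int inside the inner loop.
l1Word8 :: Vec.Vector Word8 -> Vec.Vector Word8 -> Int
l1Word8 a b =
    Vec.sum (Vec.zipWith (\x y -> abs (fromIntegral x - fromIntegral y)) a b)

-- Answer's approach: pixels are already Int, so the loop only subtracts.
l1Int :: Vec.Vector Int -> Vec.Vector Int -> Int
l1Int a b = Vec.sum (Vec.zipWith (\x y -> abs (x - y)) a b)

-- Pitfall that both variants avoid: subtracting directly in Word8
-- wraps around modulo 256 instead of going negative.
wrappedDiff :: Word8 -> Word8 -> Word8
wrappedDiff x y = x - y
```

Both functions return the same distance for the same pixel values, so switching from one to the other changes only where the conversion cost is paid, not the classification result.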