I've been trying to get better at Haskell for a while, and have recently been working on a lot of small projects with it. This one constructs a binary decision tree.
The command to run it is:
stack exec decision-tree-exe <threshold> <training file> <testing file>
where threshold is in the range (0,1].
I think I've gotten a lot better, but I'm still having problems, especially with performance and readability. For this project, I took a more top-down approach, implementing functions after using them. src/DecisionTree.hs is where the bulk of the logic is, and the file is pretty much in the order I wrote it. I would love to get some feedback from more experienced people on where I might improve.
module DecisionTree where

import Data.List (genericLength, maximumBy, nub)
import Data.Map (elemAt, foldlWithKey', fromListWith)
import Data.Ord

data DecisionTree a b
  = Node ([a] -> Bool) (DecisionTree a b) (DecisionTree a b)
  | Leaf b

type Dataset cat attrs = [(cat, [attrs])]

type Threshold = Double

type Splitter c a = ([a] -> Bool, Dataset c a, Dataset c a)

apply :: DecisionTree a b -> [a] -> b
apply (Leaf b) _ = b
apply (Node f l r) a =
  case f a of
    False -> apply l a
    True -> apply r a

train ::
     (Ord c)
  => (Dataset c a -> Maybe (Splitter c a))
  -> Dataset c a
  -> DecisionTree a c
train splitter dataset =
  case splitter dataset of
    Just (partitioner, left, right) ->
      Node partitioner (train splitter left) (train splitter right)
    Nothing -> Leaf majority
  where
    classCounts = fromListWith (+) $ map (\k -> (fst k, 1)) dataset
    majority = fst $ foldlWithKey' max (elemAt 0 classCounts) classCounts
    max acc k v
      | v > snd acc = (k, v)
      | otherwise = acc

giniSplitter ::
     (Ord a, Ord c) => Threshold -> Dataset c a -> Maybe (Splitter c a)
giniSplitter threshold dataset =
  case fst maxDelta > threshold of
    True -> Just $ snd maxDelta
    False -> Nothing
  where
    attrs = nub . concat . snd . unzip $ dataset
    partitioner a = (a `elem`)
    delta a = giniDelta (partitioner a) dataset
    maxDelta = maximumBy (comparing fst) $ map delta attrs

giniDelta :: (Eq c) => ([a] -> Bool) -> Dataset c a -> (Double, Splitter c a)
giniDelta partitioner dataset =
  ( gini dataset - (d1 / d * gini left + d2 / d * gini right)
  , (partitioner, left, right))
  where
    left = filter (not . partitioner . snd) dataset
    right = filter (partitioner . snd) dataset
    d1 = genericLength left
    d2 = genericLength right
    d = genericLength dataset

gini :: (Eq c) => Dataset c a -> Double
gini d = 1 - sum [(pj c) ** 2 | c <- nub . fst . unzip $ d]
  where
    pj c = genericLength (filter ((== c) . fst) d) / genericLength d
1 Answer
Just a few random comments:
- Both elemAt and maximumBy give hints that you're expecting to operate on non-empty structures. Maybe give Data.List.NonEmpty a try.
- A few places could be more clear with more pattern matching. E.g. max (k1, v1) k2 v2 instead of max acc k v. Or (maxDelta, splitter) = maximumBy ...
- map snd is more conventional than snd . unzip. I suspect it would be more efficient too but I might be wrong.
- In several places you're traversing the same list multiple times. In general, it's better to avoid this as it's likely to force the spine of the (potentially large) list in memory. You might be able to merge these multiple traversals into one (e.g. using the foldl package). More likely, you should simply use vector.
- In giniDelta you could use Data.List.partition to construct left and right.
- Apply top-down ordering in your where-clauses. E.g. in train, majority should come first as it is the declaration that is referenced from the main function body.
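As a rough illustration of the map snd and Data.List.partition comments, here is a minimal sketch of how gini and giniDelta might look. The type synonyms and function names are copied from the question; the sketch has not been checked against the rest of the project, so treat it as a starting point rather than a drop-in replacement.

```haskell
import Data.List (genericLength, nub, partition)

-- Type synonyms as in the question.
type Dataset c a = [(c, [a])]

type Splitter c a = ([a] -> Bool, Dataset c a, Dataset c a)

-- Gini impurity, as in the question but using `map fst`
-- instead of `fst . unzip`.
gini :: Eq c => Dataset c a -> Double
gini d = 1 - sum [pj c ** 2 | c <- nub (map fst d)]
  where
    pj c = genericLength (filter ((== c) . fst) d) / genericLength d

-- Same result as the question's giniDelta, but `partition` builds
-- both halves in one traversal instead of two `filter` calls.
giniDelta :: Eq c => ([a] -> Bool) -> Dataset c a -> (Double, Splitter c a)
giniDelta partitioner dataset =
  ( gini dataset - (d1 / d * gini left + d2 / d * gini right)
  , (partitioner, left, right))
  where
    -- `partition` returns (matching, non-matching); rows that satisfy
    -- the predicate go to the right branch, as in the original code.
    (right, left) = partition (partitioner . snd) dataset
    d1 = genericLength left
    d2 = genericLength right
    d = genericLength dataset
```

Note that this still walks the list several times (once for each genericLength and again inside gini), so it only addresses the double filter, not the broader point about repeated traversals.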
EDIT: All in all I think readability is actually pretty good!
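The first two comments above (non-empty inputs and pattern matching) could be combined along these lines. This is only a sketch: majorityClass and pick are hypothetical names that do not appear in the question, and it assumes the caller can supply the dataset as a NonEmpty list, which would require changes elsewhere in train.

```haskell
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as Map

-- Hypothetical helper: the most frequent class label in a dataset.
-- Taking a NonEmpty list means there is always at least one label,
-- so no partial function such as `elemAt 0` is needed.
majorityClass :: Ord c => NonEmpty (c, [a]) -> c
majorityClass rows = fst (Map.foldlWithKey' pick (firstClass, 0) classCounts)
  where
    -- The head of a NonEmpty list always exists.
    (firstClass, _) :| _ = rows
    -- Count how many rows carry each class label.
    classCounts =
      Map.fromListWith (+) [(c, 1 :: Int) | (c, _) <- NE.toList rows]
    -- Pattern matching on the accumulator replaces `snd acc`.
    -- Ties keep the earlier (smaller) key, as in the question's code.
    pick (k1, v1) k2 v2
      | v2 > v1 = (k2, v2)
      | otherwise = (k1, v1)
```

For example, majorityClass (("a", [1]) :| [("b", [2]), ("b", [3])]) evaluates to "b".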