I built a tree in Haskell where every node has a number that is unique in its path from root to leaf. Nodes have a dedicated list for children that are leaves. Because the tree is very large, I want to unify subtrees that are equal by some metric (i.e. keep the right one and replace the left one with a link). For this and similar tasks I wrote the treeWalker function, which walks through all nodes of a certain depth and applies a function that has access to already visited nodes via a generic cache.
treeWalker can then be called with several functions that specify how to process each node. In the example provided, I link to already seen subtrees whose path is a permutation of the path to the current node. But I use several other functions with different purposes, so I want to keep both the cache and the function applied to each node general.
{-# LANGUAGE ScopedTypeVariables #-} -- needed for treeWalker

import qualified Data.Set as Set

data LinkedTree
  = LinkedNode Int          -- index of node
               [Int]        -- leaf indices
               [LinkedTree] -- children
  | Link Int                -- index of node
         [Int]              -- path to the linked node
  deriving (Show, Eq, Ord)

newtype Path = Path [Int]
newtype ListCache a = ListCache [a]
type PathCache = ListCache Path

class Cache a where
  cacheAdd :: b -> a b -> a b
  emptyCache :: a b

instance Cache ListCache where
  cacheAdd p (ListCache x) = ListCache (p:x)
  emptyCache = ListCache []
-- walks through all nodes of specified depth and applies a function
treeWalker :: forall c a. (Cache c)
           => (LinkedTree -> [Int] -> c a -> (LinkedTree, c a)) -- function to apply on every node of desired depth
           -> Int -- depth
           -> LinkedTree
           -> LinkedTree
treeWalker processNode desiredDepth lnode = fst $ helper [] emptyCache lnode
  where
    helper :: [Int] -- path up until now, excluding current node, [level k, ..., level 1, root]
           -> c a
           -> LinkedTree
           -> (LinkedTree, c a)
    helper _ cache (Link is p) = (Link is p, cache) -- there may be links already, e.g. from a previous run with different parameters
    helper path cache (LinkedNode is lis children)
      | tooShortNoChildren = (LinkedNode is lis [], cache) -- path ends too soon, return cache as is
      | notDeepEnoughYet = (LinkedNode is lis children', cache') -- not deep enough, recurse
      | atDesiredDepth = processNode (LinkedNode is lis children) path cache
      | otherwise = error "unexpected"
      where
        tooShortNoChildren = length path < desiredDepth && null children
        notDeepEnoughYet = length path < desiredDepth
        atDesiredDepth = length path == desiredDepth
        (children', cache') = foldl g ([], cache) children -- we use a fold because the second child needs the first child's result
        g :: ([LinkedTree], c a) -> LinkedTree -> ([LinkedTree], c a)
        g (processed, cache) lt = (processed ++ [p2], c2)
          where
            p2 :: LinkedTree
            c2 :: c a
            (p2, c2) = helper (is:path) cache lt -- is is an Int, so it is consed onto the path directly
setPermutationLinks :: Int -> LinkedTree -> LinkedTree
setPermutationLinks = treeWalker processNode
  where
    processNode :: LinkedTree -> [Int] -> PathCache -> (LinkedTree, PathCache)
    processNode ln@(LinkedNode is _ _) path cache = case query cache of
        Nothing    -> (ln, cacheAdd fullPath cache) -- return node as is, add path to cache
        Just cpath -> (Link is $ reverse cpath, cache) -- return link and unchanged cache
      where
        currentPathMatches (a:as) = is == a -- both end in the same node
                                 && Set.fromList path == Set.fromList as -- remaining are identical
        fullPath = Path $ is:path
        query :: PathCache -> Maybe [Int] -- it's in the cache or not
        query (ListCache []) = Nothing
        query (ListCache (Path a:as)) | currentPathMatches a = Just a
                                      | otherwise = query $ ListCache as
The above code works, but I found it hard to come up with and difficult to debug. Is there a clearer way of implementing this?
1 Answer
tooShortNoChildren is subsumed in notDeepEnoughYet. The newtypes and class are silly, discard them. treeWalker doesn't touch cache, so let's hide cache in a monadic interface.
import Control.Monad.State (State, evalState)

-- walks through all nodes of specified depth and applies a function
treeWalker :: ([Int] -> LinkedTree -> State [a] LinkedTree) -- function to apply on every node of desired depth
           -> Int -> LinkedTree -> LinkedTree
treeWalker processNode desiredDepth = (`evalState` []) . helper [] processNode
  where
    helper :: Monad m
           => [Int] -- path up until now, excluding current node, [level k, ..., level 1, root]
           -> ([Int] -> LinkedTree -> m LinkedTree)
           -> LinkedTree -> m LinkedTree
    helper _ _ l@(Link _ _) = return l -- there may be links already, e.g. from a previous run with different parameters
    helper path f ln@(LinkedNode is lis children) =
      if length path == desiredDepth
        then f path ln
        else LinkedNode is lis <$> traverse (helper (is:path) f) children
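How State threads the cache through traverse is easier to see on a smaller problem. The following is only a sketch: markRepeat and markRepeats are hypothetical names, not part of the question's code, and Control.Monad.State is assumed to come from the mtl package.

```haskell
import Control.Monad.State (State, evalState, gets, modify)

-- Hypothetical helper: keep the first occurrence of a value and
-- replace later repeats with Nothing. The list of values seen so
-- far plays the role of the cache.
markRepeat :: Int -> State [Int] (Maybe Int)
markRepeat x = do
  seen <- gets (elem x)
  if seen
    then return Nothing                  -- already in the cache
    else modify (x:) >> return (Just x)  -- record it, keep the value

-- traverse runs the actions left to right, so every element sees
-- the cache as left behind by the elements before it.
markRepeats :: [Int] -> [Maybe Int]
markRepeats xs = evalState (traverse markRepeat xs) []
```

markRepeats [1,2,1,3,2] evaluates to [Just 1,Just 2,Nothing,Just 3,Nothing]. Each element sees what its left neighbours recorded, which is the ordering the foldl in the question enforced by hand.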
The explicit recursion has the form of a fold.
-- walks through all nodes of specified depth and applies a function
treeWalker :: ([Int] -> LinkedTree -> State [a] LinkedTree) -- function to apply on every node of desired depth
           -> Int -> LinkedTree -> LinkedTree
treeWalker processNode desiredDepth = (`evalState` []) . foldr ($) processNode (replicate desiredDepth step) []
  where
    -- Makes a node processor work at one level deeper.
    -- The path excludes the current node and has form [level k, ..., level 1, root].
    step :: Monad m => ([Int] -> LinkedTree -> m LinkedTree) -> [Int] -> LinkedTree -> m LinkedTree
    step _ _ l@(Link _ _) = return l
    step f path (LinkedNode is lis children) = LinkedNode is lis <$> traverse (f (is:path)) children
I'd inline that.
I'll assume that comparing a:as and head is:path is enough. I'll also assume that, in accordance with LinkedNode's definition, its first parameter has type Int, not [Int].
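The first assumption relaxes the match condition slightly. As a sketch of the difference, here are the two checks side by side; matchesStrict and matchesLoose are hypothetical names introduced only for this comparison, and both take full paths with the current node first.

```haskell
import qualified Data.Set as Set

-- Check used in the question: both paths end in the same node and
-- the remaining nodes are the same in any order.
matchesStrict :: [Int] -> [Int] -> Bool
matchesStrict (i:path) (a:as) = i == a && Set.fromList path == Set.fromList as
matchesStrict _ _ = False

-- Simplified check: the two full paths contain the same nodes in
-- any order, regardless of which node comes last.
matchesLoose :: [Int] -> [Int] -> Bool
matchesLoose p q = Set.fromList p == Set.fromList q
```

Both checks accept [3,2,1] against [3,1,2]. Only the looser one accepts [3,2,1] against [2,3,1], where the two paths end in different nodes, so the simplification is safe only if such pairs should also be linked.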
{-# LANGUAGE LambdaCase #-}
import Control.Monad.State (State, evalState, gets, modify)
import Data.List (find)
import qualified Data.Set as Set

setPermutationLinks :: Int -> LinkedTree -> LinkedTree
setPermutationLinks desiredDepth = (`evalState` []) . foldr ($) processNode (replicate desiredDepth liftThroughTree) [] where
  processNode :: [Int] -> LinkedTree -> State [[Int]] LinkedTree
  processNode _ l@(Link _ _) = return l -- a link sitting exactly at the desired depth is kept as is
  processNode path ln@(LinkedNode i _ _) = gets (find $ (Set.fromList (i:path) ==) . Set.fromList) >>= \case
    Nothing    -> modify ((i:path):) >> return ln
    Just cpath -> return $ Link i $ reverse cpath
  liftThroughTree :: Monad m => ([Int] -> LinkedTree -> m LinkedTree) -> [Int] -> LinkedTree -> m LinkedTree
  liftThroughTree _ _ l@(Link _ _) = return l
  liftThroughTree f path (LinkedNode i lis children) = LinkedNode i lis <$> traverse (f (i:path)) children
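To try this out, here is a small self-contained sketch. It repeats the definitions so that it compiles on its own, writes the child traversal as traverse (f (i:path)), and leaves links found at the desired depth untouched. sampleTree and linked are hypothetical names, and Control.Monad.State is assumed to come from the mtl package.

```haskell
{-# LANGUAGE LambdaCase #-}

import Control.Monad.State (State, evalState, gets, modify)
import Data.List (find)
import qualified Data.Set as Set

data LinkedTree
  = LinkedNode Int [Int] [LinkedTree] -- index, leaf indices, children
  | Link Int [Int]                    -- index, path to the linked node
  deriving (Show, Eq, Ord)

setPermutationLinks :: Int -> LinkedTree -> LinkedTree
setPermutationLinks desiredDepth =
  (`evalState` []) . foldr ($) processNode (replicate desiredDepth liftThroughTree) []
  where
    -- The state is the list of full paths (current node first) seen so far.
    processNode :: [Int] -> LinkedTree -> State [[Int]] LinkedTree
    processNode _ l@(Link _ _) = return l
    processNode path ln@(LinkedNode i _ _) =
      gets (find $ (Set.fromList (i:path) ==) . Set.fromList) >>= \case
        Nothing    -> modify ((i:path):) >> return ln
        Just cpath -> return $ Link i $ reverse cpath

    -- Makes a node processor work one level deeper.
    liftThroughTree :: Monad m
                    => ([Int] -> LinkedTree -> m LinkedTree)
                    -> [Int] -> LinkedTree -> m LinkedTree
    liftThroughTree _ _ l@(Link _ _) = return l
    liftThroughTree f path (LinkedNode i lis children) =
      LinkedNode i lis <$> traverse (f (i:path)) children

-- Hypothetical sample: root 0 with children 1 and 2, where 1 has
-- child 2 and 2 has child 1, giving the paths 0-1-2 and 0-2-1.
sampleTree :: LinkedTree
sampleTree =
  LinkedNode 0 []
    [ LinkedNode 1 [] [LinkedNode 2 [] []]
    , LinkedNode 2 [] [LinkedNode 1 [] []]
    ]

linked :: LinkedTree
linked = setPermutationLinks 2 sampleTree
```

With depth 2, the node reached via 0-1-2 is visited first and kept, and the node reached via 0-2-1 becomes Link 1 [0,1,2], because both paths contain the same set of nodes.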
Yes, LinkedNode has Int as first parameter. I simplified my code while typing it out and forgot to change it there. Should be fixed now. – user2740, Jan 15, 2020 at 17:04