I have implemented the lazy version of Prim's minimum spanning tree (MST) algorithm to solve Project Euler #107, whose answer is the total weight of the network minus the weight of its MST. I want to improve the code structure, follow prevalent conventions, and reduce the code size.
Explanation:
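I compute the MST as follows:
- Initialize the adjacency matrix, a visited set containing vertex 0, a min-priority heap (queue) holding all edges incident to vertex 0, and an MST edge sum of 0.
- Pop the shortest edge from the heap. If it is not a crossing edge (both endpoints are already visited), discard it and recurse.
- Otherwise, include that edge: add its unvisited endpoint to the visited set, push the new vertex's edges onto the heap, add the edge's weight to the MST sum, and recurse.
- When the heap is empty, the sum is the total weight of the MST.

Note: non-existent edges are assigned the value -1.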
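The steps above can be condensed into a small, self-contained function. This is a minimal sketch, not the question's code: `lazyPrim` is a hypothetical name, `Data.Set` (via `Set.minView`) stands in for `Data.Heap` so that only the `containers` and `array` packages are needed, and, as in the question, a negative matrix entry is assumed to mean "no edge".

```haskell
import Data.Array (Array, bounds, listArray, (!))
import qualified Data.Set as Set

-- Hypothetical helper: lazy Prim's algorithm on a symmetric adjacency matrix.
-- As in the question, a negative entry (e.g. -1) means "no edge".
-- Data.Set is used as the min-priority queue: Set.minView pops the smallest
-- (weight, (treeVertex, otherVertex)) tuple. Returns the total MST weight.
lazyPrim :: Array (Int, Int) Int -> Int
lazyPrim adj = go start (edgesFrom 0 start) 0
  where
    start = Set.singleton 0
    (_, (n, _)) = bounds adj  -- highest vertex index

    -- Edges from the tree vertex v to every vertex not yet visited.
    edgesFrom v visited = Set.fromList
      [ (w, (v, u)) | u <- [0 .. n], Set.notMember u visited
                    , let w = adj ! (v, u), w >= 0 ]

    go visited queue total = case Set.minView queue of
      Nothing -> total  -- queue exhausted: the tree spans every reachable vertex
      Just ((w, (_, u)), queue')
        -- Stale entry: u joined the tree after this edge was queued, so the
        -- edge is no longer a crossing edge. Discard it and keep going.
        | Set.member u visited -> go visited queue' total
        -- Crossing edge: add u to the tree, queue its edges, count the weight.
        | otherwise ->
            let visited' = Set.insert u visited
            in go visited' (Set.union queue' (edgesFrom u visited')) (total + w)
```

Because every queued edge records its tree endpoint first, only the second endpoint has to be checked when the edge is popped; edges whose far end joined the tree in the meantime are exactly the stale entries the lazy variant skips.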
Code:
```haskell
import Data.Array
import qualified Data.Heap as Heap
import Data.List.Split
import Data.Maybe
import qualified Data.Set as Set

dim = 40 :: Int

main :: IO ()
main = print . maximumSaving . toAdjacencyMatrix . map (wordsBy (==',')) . lines
         =<< readFile "txt/107.txt"
  where
    toAdjacencyMatrix mat = array ((0,0),(dim-1,dim-1))
      [ ((i,j), if val == "-" then -1 else read val)
      | i <- [0..dim-1], j <- [0..dim-1], let val = mat !! i !! j]
    totalWeight network = sum [network ! (i,j) | i <- [0..dim-1], j <- [0..dim-1], i < j, network ! (i, j) > 0]
    maximumSaving network = totalWeight network - minimumSpanningTreeEdgeSum network

minimumSpanningTreeEdgeSum :: Array (Int, Int) Int -> Int
minimumSpanningTreeEdgeSum adj =
  minimumSpanningTree'
    adj
    (Set.singleton 0)
    (Heap.fromList [(adj!(0,a),(0, a)) | a <- [1..dim-1], adj ! (0,a) > 0])
    0
  where
    minimumSpanningTree' :: Array (Int, Int) Int -> Set.Set Int -> Heap.MinPrioHeap Int (Int, Int) -> Int -> Int
    minimumSpanningTree' adj visited queue sm = case Heap.viewHead queue of
      Nothing -> sm
      (Just (weight, edge)) ->
        if isCrossingEdge edge then
          let sm' = sm + weight in
          let nextVertice = if Set.notMember (fst edge) visited then fst edge else snd edge in
          let visited' = Set.insert nextVertice visited in
          let newEdges = [(i, j) | a <- [0..dim-1], let i = min nextVertice a, let j = max nextVertice a, Set.notMember i visited, adj ! (i,j) > 0 ] in
          let queue'' = foldl (flip Heap.insert) queue' $ map (\e -> (adj ! e, e)) newEdges
          in minimumSpanningTree' adj visited' queue'' sm'
        else
          minimumSpanningTree' adj visited queue' sm
      where
        queue' = fromJust $ Heap.viewTail queue
        isCrossingEdge edge = Set.notMember (fst edge) visited ||
                              Set.notMember (snd edge) visited
```
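The -1 convention for missing edges and the `totalWeight` step can be checked in isolation with a sketch like the following. `splitCommas` and `parseNetwork` are hypothetical helpers: `splitCommas` is a base-only stand-in for `wordsBy (== ',')` that, unlike `wordsBy`, keeps empty fields, which makes no difference for well-formed rows. The sketch also assumes the input is a square matrix with one row per line.

```haskell
import Data.Array (Array, bounds, listArray, range, (!))

-- Hypothetical helper: split a line on commas, keeping empty fields.
splitCommas :: String -> [String]
splitCommas s = case break (== ',') s of
  (field, [])       -> [field]
  (field, _ : rest) -> field : splitCommas rest

-- Hypothetical helper: parse the network text into an adjacency matrix.
-- "-" marks a missing edge and becomes -1, mirroring toAdjacencyMatrix.
parseNetwork :: String -> Array (Int, Int) Int
parseNetwork txt = listArray ((0, 0), (n - 1, n - 1)) (concat cells)
  where
    cells = [ map toWeight (splitCommas l) | l <- lines txt ]
    n = length cells
    toWeight "-" = -1
    toWeight s   = read s

-- Sum of all existing edges, counting each undirected edge once (i < j),
-- like totalWeight in the question.
totalWeight :: Array (Int, Int) Int -> Int
totalWeight adj = sum [ w | (i, j) <- range (bounds adj), i < j, let w = adj ! (i, j), w > 0 ]
```

The Project Euler answer is then `totalWeight adj` minus the MST weight, which is what `maximumSaving` computes in the question.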
Answer:
I would separate the queue-processing loop from the MST-specific logic. Once the loop is a general `unrollM`, the visited set can no longer be threaded through as an extra argument, so it has to live in a `State` monad.
```haskell
{-# LANGUAGE LambdaCase #-}
import Control.Monad.State
import Control.Monad.Loops (firstM)
import Data.List (sort)

-- Pops the smallest element and passes it to `f`. `Nothing` rejects the element
-- (it is dropped); `Just as` accepts it (it is emitted) and pushes `as`.
unrollM :: (Monad m, Ord a) => (a -> m (Maybe [a])) -> [a] -> m [a]
unrollM f = step . (Heap.fromList :: Ord a => [a] -> Heap.MinHeap a) where
  step queue = case Heap.view queue of
    Nothing -> pure []
    Just (x, queue') -> f x >>= \case
      Nothing -> step queue'
      Just as -> (x:) <$> step (foldl (flip Heap.insert) queue' as)

minimumSpanningTreeEdgeSum :: Array (Int, Int) Int -> Int
minimumSpanningTreeEdgeSum adj = sum $ map fst $ (`evalState` Set.singleton 0) $ unrollM minimumSpanningTree'
  [(adj!(0,a),(0, a)) | a <- [1..dim-1], adj ! (0,a) > 0] where
  minimumSpanningTree' :: (Int, (Int, Int)) -> State (Set.Set Int) (Maybe [(Int, (Int, Int))])
  minimumSpanningTree' (_, (from, to)) = firstM (gets . Set.notMember) [from, to] >>= \case
    -- Both endpoints already visited: not a crossing edge, so reject it.
    Nothing -> pure Nothing
    Just nextVertex -> do
      modify $ Set.insert nextVertex
      visited <- get
      -- Edges from the new vertex to every vertex still outside the tree.
      pure $ Just
        [ (adj ! (i,j), (i, j))
        | a <- [0..dim-1]
        , Set.notMember a visited
        , let [i,j] = sort [nextVertex, a]
        , adj ! (i,j) > 0
        ]
```
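To try the separated-loop design without the `heap` package, here is a sketch of an `unrollM` variant whose callback returns `Maybe [a]`: `Nothing` rejects the popped item (a non-crossing edge), so it is neither emitted nor summed, while `Just new` accepts it and queues `new`. Assumptions: `Data.Set` replaces `Data.Heap` as the min-priority queue (equal items would be merged, which is harmless here because each `(weight, (v, u))` tuple is pushed at most once), `mstWeight` is a hypothetical name, `State` comes from `mtl`, and negative matrix entries mean "no edge".

```haskell
{-# LANGUAGE LambdaCase #-}
import Control.Monad.State (State, evalState, get, gets, modify)
import Data.Array (Array, bounds, listArray, (!))
import qualified Data.Set as Set

-- Generic loop: pop the smallest item, then let the callback either reject it
-- (Nothing: the item is dropped) or accept it (Just new: the item is emitted
-- and `new` is queued). Data.Set stands in for Data.Heap as the min-queue.
unrollM :: (Monad m, Ord a) => (a -> m (Maybe [a])) -> [a] -> m [a]
unrollM f = step . Set.fromList
  where
    step queue = case Set.minView queue of
      Nothing -> pure []
      Just (x, queue') -> f x >>= \case
        Nothing  -> step queue'
        Just new -> (x :) <$> step (foldr Set.insert queue' new)

-- Hypothetical MST weight via unrollM; the visited set lives in State.
-- Negative matrix entries mean "no edge", as in the question.
mstWeight :: Array (Int, Int) Int -> Int
mstWeight adj = sum . map fst $ evalState (unrollM visit (edgesFrom 0)) (Set.singleton 0)
  where
    (_, (n, _)) = bounds adj

    -- Every existing edge (weight, (v, u)) leaving vertex v.
    edgesFrom v = [ (w, (v, u)) | u <- [0 .. n], u /= v, let w = adj ! (v, u), w >= 0 ]

    visit :: (Int, (Int, Int)) -> State (Set.Set Int) (Maybe [(Int, (Int, Int))])
    visit (_, (_, u)) = gets (Set.member u) >>= \case
      True  -> pure Nothing  -- u is already in the tree: not a crossing edge
      False -> do
        modify (Set.insert u)
        visited <- get
        -- Queue only the edges that lead out of the (grown) tree.
        pure (Just [ e | e@(_, (_, t)) <- edgesFrom u, Set.notMember t visited ])
```

If the callback could not reject items, every popped edge, including stale ones whose endpoints are both already in the tree, would end up in the output list and the sum would overcount the MST weight.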