I've got the following code, written for an exercise, that finds the shortest path between two nodes using something close to Dijkstra's algorithm. I'm looking for ways to improve my general Scala style. What I currently have:
object Graphs {
  case class Node[T](value : T)
  case class Edge[T](from : Node[T], to : Node[T], dist : Int)
  def shortest[T](edges : Set[Edge[T]], start : T, end : T) : Option[List[Edge[T]]] = {
    val tentative = edges.flatMap(e => Set(e.from, e.to)).map(n => (n, if (n.value == start) Some(List.empty[Edge[T]]) else None )).toMap
    def rec(tentative : Map[Node[T], Option[List[Edge[T]]]]) : Option[List[Edge[T]]] = {
      val current = tentative.collect{ case (node, Some(route)) => (node, route)}.toList.sortBy(_._2.length).headOption
      current match {
        case None => None
        case Some((node, route)) => {
          if (node.value == end) Some(route)
          else {
            val fromHere = edges.filter(e => e.from == node)
            val tentupdates = for(edge <- fromHere if tentative.contains(edge.to)) yield {
              tentative.get(edge.to) match {
                case None => throw new Error("broken algorithm")
                case Some(Some(knownroute)) if (knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist) => (edge.to, Some(knownroute))
                case _ => (edge.to, Some(edge :: route))
              }
            }
            val newtentative = (tentative ++ tentupdates) - node
            rec(newtentative)
          }
        }
      }
    }
    rec(tentative)
  }
}
First off, I'd appreciate some feedback on correctness.
As for the algorithm itself, I already know it could be refined by keeping track of the unevaluated edges, and that for a more general solution a second accumulator holding the solved set wouldn't cost me much more; I get the general idea of how to implement that.
I'm thinking of replacing Node with Scalaz TypeTags, and would like some feedback on whether that's a good idea.
Other than that, I'd like some feedback on general style and on how to improve readability; for example, some of my lines are quite long. Also, types like Map[Node[T], Option[List[Edge[T]]]] sort of hurt my eyes, and I'd love to know how I could improve that.
Lastly, I really don't like my
case Some(Some(knownroute)) if (knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist) => (edge.to, Some(knownroute))
case _ => (edge.to, Some(edge :: route))
The first case is only used as a filter; what I'm actually looking for is something like
case None || Some(Some(knownroute)) if (knownroute.map(_.dist).sum > route.map(_.dist).sum + edge.dist) => (edge.to, Some(edge :: route))
but I don't know how to express that.
There are a few ways this could be improved.
First, a few minor quibbles about syntax:
- When using a type annotation, do not put a space before the colon; write name: Type, please. (source)
- In a chain of higher-order methods, do not use the dot to invoke the methods. For example, this
    val tentative = edges.flatMap(e => Set(e.from, e.to)).map(n => (n, if (n.value == start) Some(List.empty[Edge[T]]) else None )).toMap
  should be
    val tentative = (edges flatMap (e => Set(e.from, e.to)) map (n => (n, if (n.value == start) Some(List.empty[Edge[T]]) else None))).toMap
  (source). The outer parentheses matter: without them, the trailing .toMap would be applied to the last lambda instead of to the result of the whole chain.
- When a chain of transformations is very long, breaking the lines inside the lambdas' parentheses can be an acceptable solution:
    val tentative = (edges flatMap (
      e => Set(e.from, e.to)
    ) map (
      n => (n, if (n.value == start) Some(List.empty[Edge[T]]) else None)
    )).toMap
- yield { block } is "evil" and should be avoided. A for-comprehension can instead be written with filter, map and flatMap; that is usually less clear, but may actually be preferable when the transformations are deeply nested.
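To make that rewrite rule concrete, here is a small sketch (the object and value names are hypothetical) comparing for-comprehensions with their method-call equivalents. A single generator with a guard becomes withFilter followed by map; flatMap only appears once a second generator is involved. So a one-generator comprehension whose body yields a plain value corresponds to map, not flatMap.

```scala
object DesugarSketch {
  val xs = Set(1, 2, 3, 4)

  // One generator plus a guard ...
  val viaFor = for (x <- xs if x % 2 == 0) yield x * 10
  // ... is desugared by the compiler into withFilter followed by map
  val viaMethods = xs withFilter (_ % 2 == 0) map (_ * 10)

  // A second generator is what introduces flatMap
  val pairsFor = for (x <- xs if x < 3; y <- List("a", "b")) yield (x, y)
  val pairsMethods = xs filter (_ < 3) flatMap (x => List("a", "b") map (y => (x, y)))
}
```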
Now let's look at this piece of your code:
for(edge <- fromHere if tentative.contains(edge.to)) yield {
  tentative.get(edge.to) match {
    case None => throw new Error("broken algorithm")
    case Some(Some(knownroute)) if (knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist) => (edge.to, Some(knownroute))
    case _ => (edge.to, Some(edge :: route))
  }
}
As I said, this could be rewritten to avoid the comprehension.
fromHere filter (edge => tentative.contains(edge.to)) map { edge =>
  tentative.get(edge.to) match {
    case None => throw new Error("broken algorithm")
    case Some(Some(knownroute)) if (knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist) => (edge.to, Some(knownroute))
    case _ => (edge.to, Some(edge :: route))
  }
}
The tentative.get(...) call returns an Option, which is None if no element for that key was found. But that means we can get rid of the filter! Instead, we map over the result of the get, which also removes one level of Options, and use flatMap on the outside so that the None results are simply dropped:
fromHere flatMap { edge =>
  tentative.get(edge.to) map {
    case Some(knownroute) if (knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist) => (edge.to, Some(knownroute))
    case _ => (edge.to, Some(edge :: route))
  }
}
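The step above relies on two facts: Map.get returns None for a missing key, and flatMap over a collection treats an Option as a collection of zero or one element, so the None results simply disappear. A toy sketch, using hypothetical String keys and Int distances instead of the question's Node and Edge types:

```scala
object GetFlatMapSketch {
  val tentative: Map[String, Option[Int]] = Map("b" -> None, "c" -> Some(4))
  val targets = Set("b", "c", "d") // "d" is not (or no longer) in tentative

  // get yields Option[Option[Int]]; mapping over it strips one Option level,
  // and flatMap silently drops the None produced for the missing key "d"
  val updates: Set[(String, Option[Int])] = targets flatMap { t =>
    tentative.get(t) map {
      case Some(known) if known < 5 => (t, Some(known))
      case _                        => (t, Some(5))
    }
  }
}
```

Here "b" has no known distance yet and gets the new one, "c" keeps its shorter known distance, and "d" produces no entry at all.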
Destructuring the Option with pattern matching is a bit tedious. What we actually want is to do one thing, or else fall back to some default case:
fromHere flatMap { edge =>
  tentative.get(edge.to) map { maybeRoute =>
    maybeRoute filter (knownroute =>
      knownroute.map(_.dist).sum < route.map(_.dist).sum + edge.dist
    ) map (knownroute =>
      (edge.to, Some(knownroute))
    ) getOrElse ((edge.to, Some(edge :: route)))
  }
}
But isn't this an unreadable mess? Yes, somewhat. We can improve it by adding a Route class, e.g.:
case class Route[T](route: List[Edge[T]], dist: Int) {
  val length = route.length
  def this(route: List[Edge[T]]) = this(route, route.map(_.dist).sum)
  def this() = this(List.empty[Edge[T]], 0)
  def ::(edge: Edge[T]) = Route(edge :: route, dist + edge.dist)
}
val tentative: Map[Node[T], Option[Route[T]]] = ...
The main advantage is that this keeps track of a route's distance, which makes the filter/map/getOrElse version slightly more accessible:
fromHere flatMap { edge =>
  tentative.get(edge.to) map { maybeRoute =>
    val maybeBetterRoute =
      maybeRoute filter (knownRoute => knownRoute.dist < route.dist + edge.dist)
    maybeBetterRoute map (knownRoute =>
      (edge.to, Some(knownRoute))
    ) getOrElse (
      (edge.to, Some(edge :: route))
    )
  }
}
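Two details of the Route approach are worth spelling out. First, because the method is named ::, which ends in a colon, edge :: route is right-associative and calls route.::(edge), exactly like List. Second, the Map[Node[T], Option[List[Edge[T]]]] type from the question can be shortened further with a type alias. A self-contained check, repeating the Route class inside a hypothetical RouteSketch object, with a hypothetical alias name Tentative:

```scala
object RouteSketch {
  case class Node[T](value: T)
  case class Edge[T](from: Node[T], to: Node[T], dist: Int)

  case class Route[T](route: List[Edge[T]], dist: Int) {
    val length = route.length
    def this(route: List[Edge[T]]) = this(route, route.map(_.dist).sum)
    def this() = this(List.empty[Edge[T]], 0)
    // The name ends in ':', so `edge :: r` is right-associative and means r.::(edge)
    def ::(edge: Edge[T]): Route[T] = Route(edge :: route, dist + edge.dist)
  }

  // Hypothetical alias: each unvisited node maps to its best known route, if any
  type Tentative[T] = Map[Node[T], Option[Route[T]]]
}
```

Prepending edges then reads like building a list, and the distance is updated along the way: bc :: ab :: new Route[String]() is a route of two edges whose dist is the sum of both weights.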
You calculate fromHere on every step, which I think is an O(n) scan over all edges:
val fromHere = edges.filter(e => e.from == node)
It may be better to build a Map[Node[T], Set[Edge[T]]] before the recursion, which is also a fairly cheap operation:
val edgesBySource = edges groupBy (_.from)
and then use val fromHere = edgesBySource.getOrElse(node, Set.empty[Edge[T]]). I assume this would pay off soon for non-tiny graphs.
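A self-contained sketch of that index, with hypothetical helper names index and outgoing. It also shows why getOrElse is needed: a node without outgoing edges (a sink such as node 3 below) has no key at all in the groupBy result.

```scala
object AdjacencySketch {
  case class Node[T](value: T)
  case class Edge[T](from: Node[T], to: Node[T], dist: Int)

  // Build the index once, before the recursion: a single pass over all edges
  def index[T](edges: Set[Edge[T]]): Map[Node[T], Set[Edge[T]]] =
    edges groupBy (_.from)

  // Each step is then a map lookup; nodes without outgoing edges have no key
  def outgoing[T](bySource: Map[Node[T], Set[Edge[T]]], node: Node[T]): Set[Edge[T]] =
    bySource.getOrElse(node, Set.empty[Edge[T]])
}
```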
Great stuff! I'll get cracking on it! (Martijn, Mar 13, 2014 at 13:40)
route.map(_.dist).sum. I know this is inefficient if the routes are quite large. I considered keeping a Tuple2[List[Edge[T]], Int] rather than just a List[Edge[T]], but I figured that was trivial anyway and would hurt the clarity of the implementation (a production version would have it, but this is basically a toy). If you would like to propose changes to the algorithm, by all means do!

.sortBy(_._2.length) is not optimal. It should be adding the weights together and choosing the next candidate based on the shortest weighted path.
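To act on that last comment, here is a sketch (not the asker's code; the object name WeightedShortest is hypothetical) that keeps the question's recursive structure but picks the reached node with the smallest summed weight instead of the fewest edges. It assumes non-negative weights, as Dijkstra's algorithm does. It also folds in some suggestions from above: the distance travels with the route as a (route, distance) pair (the Tuple2 mentioned in the comment), outgoing edges come from a groupBy index, and Option.forall, which is true for None, expresses the "None, or the new route is shorter" case the question asks about. A pattern alternative such as case None | Some(Some(knownroute)) would not compile, because alternatives may not bind variables.

```scala
object WeightedShortest {
  case class Node[T](value: T)
  case class Edge[T](from: Node[T], to: Node[T], dist: Int)

  // Unvisited node -> best known (reversed route, total distance), if any
  type Tentative[T] = Map[Node[T], Option[(List[Edge[T]], Int)]]

  def shortest[T](edges: Set[Edge[T]], start: T, end: T): Option[List[Edge[T]]] = {
    val bySource = edges.groupBy(_.from)
    val nodes = edges.flatMap(e => Set(e.from, e.to))
    val initial: Tentative[T] =
      nodes.map(n => n -> (if (n.value == start) Some((List.empty[Edge[T]], 0)) else None)).toMap

    @annotation.tailrec
    def rec(tentative: Tentative[T]): Option[List[Edge[T]]] = {
      val reached = tentative.collect { case (node, Some(best)) => (node, best) }
      if (reached.isEmpty) None
      else {
        // Dijkstra's selection rule: smallest summed weight, not fewest edges
        val (node, (route, dist)) = reached.minBy(_._2._2)
        if (node.value == end) Some(route.reverse)
        else {
          val updates = bySource.getOrElse(node, Set.empty[Edge[T]]).flatMap { edge =>
            // get is None for already-settled nodes, so flatMap skips them
            tentative.get(edge.to).map { known =>
              val candidate = dist + edge.dist
              if (known.forall(_._2 > candidate)) edge.to -> Some((edge :: route, candidate))
              else edge.to -> known
            }
          }
          rec((tentative ++ updates) - node)
        }
      }
    }

    rec(initial)
  }
}
```

With the edges a->b (weight 1), b->c (weight 1) and a->c (weight 10), this returns the two-edge route through b, whereas selecting by hop count can settle c via the direct edge first, since both tentative routes then have length 1.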