I have translated Dijkstra's algorithms (uni- and bidirectional variants) from Java to Python, eventually coming up with this:
Dijkstra.py
import heapq
from Digraph import Digraph
from random import choice
from random import uniform
from time import time
__author__ = 'Rodion "rodde" Efremov'
class HeapEntry:
def __init__(self, node, priority):
self.node = node
self.priority = priority
def __lt__(self, other):
return self.priority < other.priority
def traceback_path(target, parents):
path = []
while target:
path.append(target)
target = parents[target]
return list(reversed(path))
def bi_traceback_path(touch_node, parentsa, parentsb):
path = traceback_path(touch_node, parentsa)
touch_node = parentsb[touch_node]
while touch_node:
path.append(touch_node)
touch_node = parentsb[touch_node]
return path
def dijkstra(graph, source, target):
open = [HeapEntry(source, 0.0)]
closed = set()
parents = dict()
distance = dict()
parents[source] = None
distance[source] = 0.0
while open:
top_heap_entry = heapq.heappop(open)
current = top_heap_entry.node
if current == target:
return traceback_path(target, parents)
closed.add(current)
for child in graph.get_children_of(current):
if child in closed:
continue
tentative_cost = distance[current] + graph.get_arc_weight(current, child)
if child not in distance.keys() or distance[child] > tentative_cost:
distance[child] = tentative_cost
parents[child] = current
heap_entry = HeapEntry(child, tentative_cost)
heapq.heappush(open, heap_entry)
return [] # Target not reachable from source, return empty list.
def bidirectional_dijkstra(graph, source, target):
opena = [HeapEntry(source, 0.0)]
openb = [HeapEntry(target, 0.0)]
closeda = set()
closedb = set()
parentsa = dict()
parentsb = dict()
distancea = dict()
distanceb = dict()
best_path_length = {'value': 1e9}
touch_node = {'value': None}
parentsa[source] = None
parentsb[target] = None
distancea[source] = 0.0
distanceb[target] = 0.0
def update_forward_frontier(node, node_score):
if node in closedb:
path_length = distanceb[node] + node_score
if best_path_length['value'] > path_length:
best_path_length['value'] = path_length
touch_node['value'] = node
def update_backward_frontier(node, node_score):
if node in closeda:
path_length = distancea[node] + node_score
if best_path_length['value'] > path_length:
best_path_length['value'] = path_length
touch_node['value'] = node
def expand_forward_frontier():
current = heapq.heappop(opena).node
closeda.add(current)
for child in graph.get_children_of(current):
if child in closeda:
continue
tentative_score = distancea[current] + graph.get_arc_weight(current, child)
if child not in distancea.keys() or tentative_score < distancea[child]:
distancea[child] = tentative_score
parentsa[child] = current
heapq.heappush(opena, HeapEntry(child, tentative_score))
update_forward_frontier(child, tentative_score)
def expand_backward_frontier():
current = heapq.heappop(openb).node
closedb.add(current)
for parent in graph.get_parents_of(current):
if parent in closedb:
continue
tentative_score = distanceb[current] + graph.get_arc_weight(parent, current)
if parent not in distanceb.keys() or tentative_score < distanceb[parent]:
distanceb[parent] = tentative_score
parentsb[parent] = current
heapq.heappush(openb, HeapEntry(parent, tentative_score))
update_backward_frontier(parent, tentative_score)
while opena and openb:
tmp = distancea[opena[0].node] + distanceb[openb[0].node]
if tmp >= best_path_length['value']:
return bi_traceback_path(touch_node['value'], parentsa, parentsb)
if len(opena) + len(closeda) < len(openb) + len(closedb):
expand_forward_frontier()
else:
expand_backward_frontier()
return []
def create_random_digraph(nodes, arcs, max_weight):
graph = Digraph()
node_list = []
for node in range(nodes):
graph.add_node(node)
node_list.append(node)
for _ in range(arcs):
weight = uniform(0.0, max_weight)
graph.add_arc(choice(node_list),
choice(node_list),
weight)
return graph, node_list
def path_cost(graph, path):
cost = 0.0
for i in range(len(path) - 1):
tail = path[i]
head = path[i + 1]
if not graph.has_arc(tail, head):
raise Exception("Not a path.")
cost += graph.get_arc_weight(tail, head)
return cost
def main():
graph, node_list = create_random_digraph(1000000, 5000000, 10.0)
source = choice(node_list)
target = choice(node_list)
del node_list[:]
print("Source:", source)
print("Target:", target)
start_time = time()
path1 = dijkstra(graph, source, target)
end_time = time()
print("Dijkstra's algorithm in", 1000.0 * (end_time - start_time), "milliseconds.")
start_time = time()
path2 = bidirectional_dijkstra(graph, source, target)
end_time = time()
print("Bidirectional Dijkstra's algorithm in", 1000.0 * (end_time - start_time), "milliseconds.")
print("Paths are identical:", path1 == path2)
print("Dijkstra path:")
for node in path1:
print(node)
print("Path length:", path_cost(graph, path1))
print("Bidirectional path:")
for node in path2:
print(node)
print("Path length:", path_cost(graph, path2))
if __name__ == "__main__":
main()
You can find the data structure for representing the graph here.
Performance figures can be as optimistic as these:
Dijkstra's algorithm in 37223.119020462036 milliseconds. Bidirectional Dijkstra's algorithm in 93.41907501220703 milliseconds. Paths are identical: True
Please, tell me anything that comes to mind.
1 Answer 1
from random import choice from random import uniform
If you are importing two things from the same module, put them on one line:
from random import choice, uniform
Your imports aren't in the order defined by PEP 8:
Imports should be grouped in the following order:
standard library imports related third party imports local application/library specific imports
You should put a blank line between each group of imports.
Put any relevant
__all__
specification after the imports.
Your imports should look like this:
import heapq
from random import choice, uniform
from time import time
from Digraph import Digraph
First, the standard library imports heapq
, random
, and time
are put in alphabetical order. Then a blank line between that group and the local application specific import, Digraph
.
def traceback_path(target, parents): path = [] while target: path.append(target) target = parents[target] return list(reversed(path))
I would use return path[::-1]
. That way, you stay with a list instead of creating a reversed
object and then converting that to a list.
if child not in distance.keys() ...
The __contains__
method of a dictionary already searches through the keys. Using .keys()
just puts extra information into memory that isn't needed. Just do if child not in distance ...
You use that pattern in several places.
best_path_length = {'value': 1e9} touch_node = {'value': None}
It looks like you have these as dictionaries so that you can change the values in sub-functions. Since they appear to be related closely enough that they are always modified at the same time, I would suggest combining them:
info = {'length': 1e9, 'node': None}
update_forward_frontier()
and update_backward_frontier()
are almost identical. I would suggest merging them:
def update_frontier(node, node_score, closed):
if node in closed:
...
You could do something similar for the expand_..._frontier()
functions.
node_list = [] ... node_list.append(node)
Your Digraph
class already keeps track of a node list. If you want to get a list from it, use node_list = list(graph.nodes)
since the order doesn't appear to matter.
raise Exception("Not a path.")
The exception class used should give some sort of indication of what type of error it is. Using Exception
doesn't do that. You should either use a more specific standard exception class such as ValueError
or define your own.
print("Bidirectional Dijkstra's algorithm in", 1000.0 * (end_time - start_time), "milliseconds.")
That's 97 characters not including the indentation. PEP 8 says:
Limit all lines to a maximum of 79 characters.
I would make your code look more like this:
m_secs = 1000.0 * (end_time - start_time)
print("Bidirection Dijkstra's algorithm in {} milliseconds.".format(m_secs))
That's 76 characters without the indentation. It's a little long, but it's a long string that you're printing.
print("Paths are identical:", path1 == path2)
Won't the user find it a little string that it prints Paths are identical: True
instead of Paths are identical.
? I might do something like this:
print("Paths are {}identical".format("not " * path1 != path2))
or:
print("Paths are {}identical".format("" if path1 == path2 else "not "))
-
\$\begingroup\$
path.insert(0, target)
How Python lists are implemented? Array-based or linked list? \$\endgroup\$coderodde– coderodde2016年04月02日 12:20:49 +00:00Commented Apr 2, 2016 at 12:20 -
-
\$\begingroup\$ For that very reason
insert(0, target)
will degrade to \$\Theta(k^2)\$ where \$k\$ is the amount of nodes in the path. \$\endgroup\$coderodde– coderodde2016年04月02日 12:33:40 +00:00Commented Apr 2, 2016 at 12:33 -
\$\begingroup\$ @coderodde: I defer. I will remove that in just a minute.
return path[::-1]
is, I believe, what you were looking for. \$\endgroup\$zondo– zondo2016年04月02日 12:34:47 +00:00Commented Apr 2, 2016 at 12:34 -
\$\begingroup\$ Not that it makes much difference which one to use, since often \$k\$ is small. \$\endgroup\$coderodde– coderodde2016年04月02日 12:35:48 +00:00Commented Apr 2, 2016 at 12:35
Explore related questions
See similar questions with these tags.