I implemented Dijkstra's algorithm. I'm looking for feedback on how to make this more Pythonic. I'm also wondering whether I should move the get_shortest_path
method out of the Graph class; this would mean I need to expose the vertex list.
The MutablePriorityQueue is just the code snippet from https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes, which I use so that I can update the priority of an item that is already in the queue.
import math
from mutable_priority_queue import MutablePriorityQueue
class Graph:
    """
    Undirected graph whose vertices have 2-D positions.

    Vertices are kept in insertion order and addressed by their position in
    that order. Edge weights are not stored; the weight of an edge is the
    Euclidean distance between its two end points.
    """

    def __init__(self):
        self._vertices = []

    def add_new_vertex(self, x, y):
        """
        Adds a new vertex to the graph.
        :param x: X position of the vertex.
        :param y: Y position of the vertex.
        :return: The newly added vertex in the graph.
        """
        # The new vertex's index is its position in the vertex list.
        new_vertex = Vertex(len(self._vertices), x, y)
        self._vertices.append(new_vertex)
        return new_vertex

    def get_vertex(self, i):
        """
        Returns the vertex at the i'th index.
        :param i: The index of the vertex in our vertex list.
        """
        return self._vertices[i]

    @staticmethod
    def _calculate_distance(v0, v1):
        """
        :return: The Euclidean distance between the positions of v0 and v1.
        """
        return math.hypot(v0.x - v1.x, v0.y - v1.y)

    @staticmethod
    def _get_path(vertex):
        """
        Follows the 'previous' links back from vertex until a vertex without
        a predecessor is reached.
        :param vertex: The last vertex of the path.
        :return: A list of vertices ordered from the start of the chain to vertex.
        """
        # Collect back-to-front and reverse once; inserting at the head of a
        # list for every step would be quadratic in the path length.
        path = []
        while vertex is not None:
            path.append(vertex)
            vertex = vertex.previous
        path.reverse()
        return path

    def _check_index(self, index, label):
        """
        Raises IndexError unless index addresses an existing vertex.
        :param index: The vertex index to validate.
        :param label: Name of the argument, used as prefix of the error message.
        """
        # Negative indices are rejected too: they would silently wrap around
        # in the list lookup and can never match a vertex's own index.
        if not 0 <= index < len(self._vertices):
            raise IndexError(label + ' vertex index is out of range.')

    def get_shortest_path(self, source_vertex_index, destination_vertex_index):
        """
        Calculates the shortest path between source and destination using Dijkstra's algorithm
        :param source_vertex_index: The index of the vertex we start at.
        :param destination_vertex_index: The index of the vertex we want to calculate a path to.
        :return: A list with the vertices making up the shortest path from source to destination,
                 or an empty list when the destination cannot be reached from the source.
        :raises IndexError: When either index does not address a vertex of this graph.
        """
        self._check_index(source_vertex_index, 'source')
        self._check_index(destination_vertex_index, 'destination')

        # The search state lives on the vertices themselves, so wipe whatever
        # an earlier search left behind before starting a new one.
        for vertex in self._vertices:
            vertex.distance = math.inf
            vertex.previous = None

        priority_queue = MutablePriorityQueue()
        visited_vertices = set()

        # The source is 0 distance away from itself.
        source_vertex = self._vertices[source_vertex_index]
        source_vertex.distance = 0
        priority_queue.add_or_update(source_vertex, 0)

        while priority_queue:
            # Take the unvisited vertex that is closest to the source.
            # Note: in the first iteration this is the source vertex itself.
            current_vertex = priority_queue.pop()
            if current_vertex.index == destination_vertex_index:
                return Graph._get_path(current_vertex)
            visited_vertices.add(current_vertex)

            for neighbour_index in range(current_vertex.degree):
                neighbour_vertex = current_vertex.get_neighbour(neighbour_index)
                # A visited neighbour already has its final distance.
                if neighbour_vertex in visited_vertices:
                    continue

                edge_length = Graph._calculate_distance(current_vertex, neighbour_vertex)
                tentative_distance = current_vertex.distance + edge_length
                # A lower distance means we found a more direct route (or the
                # neighbour had not been reached before).
                if tentative_distance < neighbour_vertex.distance:
                    neighbour_vertex.distance = tentative_distance
                    neighbour_vertex.previous = current_vertex
                    priority_queue.add_or_update(neighbour_vertex, tentative_distance)

        # The queue ran dry without reaching the destination: no path exists.
        return []
class Vertex:
    """
    A graph vertex with a 2-D position and a list of neighbouring vertices.

    The 'distance' and 'previous' attributes are scratch space written by
    Graph.get_shortest_path while it searches.
    """

    def __init__(self, index, x, y):
        """
        :param index: Non-negative integer index of this vertex.
        :param x: X position of the vertex.
        :param y: Y position of the vertex.
        :raises TypeError: When index is not an integer.
        :raises IndexError: When index is negative.
        """
        if not isinstance(index, int):
            raise TypeError('Index needs to be of type integer.')
        if index < 0:
            # Report the offending value rather than a hard-coded -1.
            raise IndexError('Index out of range ({}).'.format(index))
        self.index = index
        self.x = x
        self.y = y
        self.distance = float("inf")
        self.previous = None
        self._neighbours = []

    def __repr__(self):
        return 'Vertex(index={!r}, x={!r}, y={!r})'.format(self.index, self.x, self.y)

    @property
    def degree(self):
        """
        :return: The number of edges connected to this vertex.
        """
        # Derived from the neighbour list so the two can never disagree.
        return len(self._neighbours)

    @property
    def neighbours(self):
        """
        :return: A read-only snapshot (tuple) of the neighbouring vertices.
        """
        return tuple(self._neighbours)

    def get_neighbour(self, index):
        """
        :param index: The 0-based index of our neighbour.
        :return: The neighbour vertex at the provided index.
        """
        return self._neighbours[index]

    def create_edge(self, neighbour_vertex):
        """
        Creates an edge between this vertex and the neighbour.
        The edge is undirected: it is registered on both vertices. Creating
        an edge that already exists does nothing.
        :param neighbour_vertex: The vertex we create an edge between.
        """
        if neighbour_vertex in self._neighbours:
            return
        self._neighbours.append(neighbour_vertex)
        # Register the reverse direction; the membership check above stops
        # the mutual recursion on the way back.
        neighbour_vertex.create_edge(self)
Unit tests:
from unittest import TestCase
from undirected_graph import Graph, Vertex
class TestUndirectedGraph(TestCase):
    """Unit tests for Vertex construction and edges, and for Graph path finding."""

    def test_negative_vertex_index(self):
        # Constructing a vertex with a negative index must be rejected.
        with self.assertRaises(IndexError):
            Vertex(-1, 0, 0)

    def test_add_neighbour_check_degree(self):
        # Two free-standing vertices, joined by a single edge.
        first = Vertex(0, 10, 10)
        second = Vertex(1, 20, 20)

        first.create_edge(second)

        # The edge must be visible from both ends.
        self.assertEqual(1, first.degree)
        self.assertEqual(1, second.degree)
        self.assertEqual(second, first.get_neighbour(0))
        self.assertEqual(first, second.get_neighbour(0))

    def test_add_first_vertex(self):
        # The first vertex added to a fresh graph gets index 0.
        only_vertex = Graph().add_new_vertex(10, 10)

        self.assertEqual(0, only_vertex.index)

    def test_add_two_vertices(self):
        # Indices are handed out in insertion order.
        graph = Graph()

        first = graph.add_new_vertex(10, 10)
        second = graph.add_new_vertex(20, 20)

        self.assertEqual(0, first.index)
        self.assertEqual(1, second.index)

    def test_distance_empty_graph(self):
        # Any index is out of range while the graph has no vertices.
        graph = Graph()

        with self.assertRaises(IndexError):
            graph.get_shortest_path(0, 0)

    def test_distance_single_vertex(self):
        # The path from a vertex to itself is just that vertex.
        graph = Graph()
        graph.add_new_vertex(0, 0)

        route = graph.get_shortest_path(0, 0)

        self.assertEqual(len(route), 1)
        self.assertEqual(route[0], graph.get_vertex(0))

    def test_distance_two_vertices(self):
        # Two connected vertices: the path is source followed by destination.
        graph = Graph()
        start = graph.add_new_vertex(0, 0)
        end = graph.add_new_vertex(10, 0)
        start.create_edge(end)

        route = graph.get_shortest_path(0, 1)

        self.assertEqual(len(route), 2)
        self.assertEqual(route[0], graph.get_vertex(0))
        self.assertEqual(route[1], graph.get_vertex(1))
1 Answer 1
My comments are not necessarily Python-specific; they mostly question some of the modelling choices.
Why does vertex take index as an argument? The array of vertices belongs to the graph, is maintained and manipulated by the graph. The question "what is the index of this vertex" should always be answered by the graph object.
In general, the index should be something internal to graph. You don't want to be passing around the index of vertices everywhere, but the vertices themselves.
path = graph.get_shortest_path(v0, v1)
As you write bigger programs, keeping track of the index of each vertex at some higher level of abstraction (the thing that creates the graph object) is probably going to be cumbersome. Also, you can then never sort or reorder the list of vertices, or remove a vertex from the list, because you'll screw up the indexing. You might have a similar problem with the neighbor list in the Vertex class.
You could expose the neighbors in the Vertex class, which makes looping simpler:
for n in current_vertex.neighbors: