So, this is a ton of code, but that's what I came up with for an efficient and extendable implementation of the A* search algorithm.
The first four classes can be seen as interfaces to show the user what the interface looks like. (Is this a good way of doing that?) I can also provide examples if needed or wanted.
"""
Tools and classes for finding the best path between two points.
"""
from functools import total_ordering
from utils.classes import Abstract
from utils.priority_queue import PriorityQueue
from utils.iterable import CacheDict
@total_ordering
class Position(Abstract):
def __hash__(self):
pass
def __eq__(self, other):
pass
def __lt__(self, other):
pass
def __repr__(self):
pass
class TiledMap(Abstract):
def iter_accessible_adjacent(self, position, for_mover):
pass
def is_accessible(self, position, for_mover):
pass
def get_cost(self, from_position, to_position, for_mover):
pass
class Heuristic(Abstract):
@staticmethod
def get_cost(from_position, to_position):
pass
class Mover(Abstract):
def can_move_on(self, tile):
pass
class Node(object):
def __init__(self, position):
self.position = position
self.movement_cost = 0
self.heuristic_cost = 0
self._distance_in_tiles = 0
self._predecessor = None
def set_predecessor(self, new_predecessor):
self._distance_in_tiles = new_predecessor.get_distance_in_tiles() + 1
self._predecessor = new_predecessor
def get_distance_in_tiles(self):
return self._distance_in_tiles
def get_predecessor(self):
return self._predecessor
def get_path_score(self):
return self.movement_cost + self.heuristic_cost
class AStarPathFinder(object):
"""
A reusable implementation of the A* search algorithm.
"""
def __init__(self, tiled_map, heuristic, max_distance_in_tiles):
self._tiled_map = tiled_map
self._heuristic = heuristic
self._max_distance_in_tiles = max_distance_in_tiles
self.set_observe_function(lambda node: None)
self._reset_path_data()
def set_observe_function(self, func):
self._observe = func
def _reset_path_data(self):
self._closed = set()
self._open = PriorityQueue(key=lambda node: node.get_path_score())
self._nodes = CacheDict(Node)
def get_path(self, for_mover, start_position, target_position):
self._for_mover = for_mover
self._start_position = start_position
self._target_position = target_position
path = []
if self._is_target_accessible():
if self._has_found_path():
path = self._get_retraced_path()
self._reset_path_data()
return path
def _is_target_accessible(self):
return self._tiled_map.is_accessible(self._target_position, self._for_mover)
def _has_found_path(self):
current_node = self._nodes[self._start_position]
self._open.add(current_node)
while self._open and self._is_within_max_range(current_node):
current_node = self._open.pop()
self._closed.add(current_node.position)
if self._target_position in self._closed:
return True
for at_position in self._iter_accessible_adjacent(current_node):
self._evaluate_adjacent(current_node, at_position)
return False
def _is_within_max_range(self, current_node):
return current_node.get_distance_in_tiles() < self._max_distance_in_tiles
def _iter_accessible_adjacent(self, node):
return self._tiled_map.iter_accessible_adjacent(node.position, self._for_mover)
def _get_retraced_path(self):
path = []
node = self._nodes[self._target_position]
while node.position != self._start_position:
path.append(node.position)
node = node.get_predecessor()
path.reverse()
return path
def _evaluate_adjacent(self, current_node, position):
node = self._nodes[position]
tentative_movement_cost = self._get_movement_cost(current_node, node)
if tentative_movement_cost < node.movement_cost:
if node in self._open:
self._open.remove(node)
elif position in self._closed:
self._closed.remove(position)
if position not in self._closed and node not in self._open:
node.set_predecessor(current_node)
node.movement_cost = tentative_movement_cost
node.heuristic_cost = self._get_heuristic_cost(position)
self._open.add(node)
self._observe(node)
def _get_movement_cost(self, from_node, to_node):
new_movement_cost = self._tiled_map.get_cost(from_node.position, to_node.position, self._for_mover)
return from_node.movement_cost + new_movement_cost
def _get_heuristic_cost(self, position):
return self._heuristic.get_cost(position, self._target_position)
The code for CacheDict
that is used by AStarPathFinder
:
class CacheDict(dict):
"""
Acts like a normal dict, but when trying to get an item by a key that
is not contained, the provided function is called with the key as the
argument(s) and it's return value is stored with the key.
value that is returned by the invocation of the
provided function with the key as the argument is stored for the key.
If the key already exists in the dictionary, the stored value is returned.
>>> def very_slow_function(x):
... a, b = x
... print('very_slow_function is executed')
... return (a * b) + a + b
>>> cd = CacheDict(very_slow_function)
>>> cd[(3,4)]
very_slow_function is executed
19
>>> cd[(3,4)]
19
>>> cd
{(3, 4): 19}
"""
def __init__(self, func):
self.func = func
def __missing__(self, key):
ret = self[key] = self.func(key)
return ret
The PriorityQueue
looks like this: (Mostly the rewritten version by Gareth Rees from my previous code review)
from heapq import heappush, heappop
class PriorityQueue(object):
"""
A priority queue with O(log n) addition, O(1) membership test and
amortized O(log n) removal.
The `key` argument specifies a function that returns the score for an
element in the priority queue. (If not supplied, an element is its own score).
>>> q = PriorityQueue([3, 1, 4])
>>> q.pop()
1
>>> q.add(2); q.pop()
2
>>> q.remove(3); q.pop()
4
>>> list(q)
[]
>>> bool(q)
False
>>> q.pop()
Traceback (most recent call last):
...
IndexError: index out of range
>>> q = PriorityQueue('length of text'.split(), key = lambda s:len(s))
>>> q.pop()
'of'
"""
def __init__(self, *args, **kwargs):
self._key = kwargs.pop('key', lambda x:x)
self._heap = []
self._dict = {}
if args:
for elem in args[0]:
self.add(elem)
def __nonzero__(self):
return bool(self._dict)
def __iter__(self):
return iter(self._dict)
def __contains__(self, element):
return element in self._dict
def add(self, element):
"""
Add an element to the priority queue.
"""
e = PriorityQueueElement(element, self._key(element))
self._dict[element] = e
heappush(self._heap, e)
def remove(self, element):
"""
Remove an element from the priority queue.
If the element is not a member, raise KeyError.
"""
e = self._dict.pop(element)
e.removed = True
def pop(self):
"""
Remove and return the element with the smallest score from the
priority queue.
"""
while True:
e = heappop(self._heap)
if not e.removed:
del self._dict[e.element]
return e.element
class PriorityQueueElement(object):
"""
A proxy for an element in a priority queue that remembers (and
compares according to) its score.
"""
def __init__(self, element, score):
self.element = element
self._score = score
self.removed = False
def __lt__(self, other):
return self._score < other._score
What could be simpler, more efficient or worded better?
-
\$\begingroup\$ This is a lot of code. Would you be interested in putting on github and making a link? That might make it easier for people to recommend changes/improvements, too. \$\endgroup\$user809695– user8096952013年10月17日 05:17:42 +00:00Commented Oct 17, 2013 at 5:17
-
\$\begingroup\$ Good idea, I'll probably do that later. Bitbucket is okay too, right? (hg is easier to use.) \$\endgroup\$Joschua– Joschua2013年10月26日 18:10:02 +00:00Commented Oct 26, 2013 at 18:10
-
\$\begingroup\$ Yeah of course; any common version control system would be fine. \$\endgroup\$user809695– user8096952013年10月29日 04:23:45 +00:00Commented Oct 29, 2013 at 4:23
1 Answer 1
You write, "The [abstract] classes can be seen as interfaces to show the user what the interface looks like." But none of these abstract classes have any documentation, so how are people expected to figure out how to use them? I mean, in some cases we might be able to figure it out by reading the code, but what about this method:
def iter_accessible_adjacent(self, position, for_mover): pass
Presumably this is supposed to return an iterator over the positions that are adjacent to
position
, and accessible from it, but what are we supposed to pass for thefor_mover
argument?Why do
Position
objects need to have a total order? You shouldn't need to compare them, and coordinates have no natural ordering anyway. Also, why do they need a__repr__
method? You don't seem to callrepr
anywhere in this code.In several cases you could provide actual implementations of the methods, instead of just writing
pass
. For example, if the idea of thefor_mover
argument tois_accessible
is that it's supposed to be aMover
object, you could write:def is_accessible(self, position, for_mover): return for_mover.can_move_on(position)
Indeed, it's not clear to me why anyone would want to implement it any other way. Similarly,
Heuristic.get_cost
could be:def get_cost(from_position, to_position): return 0
Since 0 is always an admissible estimate.
Is it important that
Mover.can_move_on
takes atile
argument rather than aposition
. What is a tile, anyway?The
Heuristic
class seems useless: it doesn't have any instance methods. Why not just use a function?The
Heuristic.get_cost
method seems poorly named: the A* algorithm uses an admissible estimate of the cost, not the cost itself.The
TiledMap
class seems poorly named: it doesn't seem to have anything to do with tiles or maps. What it knows about is which positions are adjacent, and the cost of travelling between adjacent positions. This kind of data structure is normally known as a weighted graph.Similarly, the
Node
andAStarPathFinder
classes talk about "distance in tiles" but really you are talking about the number of edges in a path in a weighted graph.Your A* algorithm takes a
max_distance_in_tiles
. But what if I don't want to specify a maximum distance? For the class to be really re-usable, there ought to be a way of avoiding having to specify this.Instead of writing your own
Abstract
class, why not use the built-inabc.ABCMeta
? Similarly, why not decorate your abstract methods with theabc.abstractmethod
decorator? Then you'd get useful exceptions if anyone ever tried to instantiate an abstract base class.Position
could inherit fromcollections.abc.Hashable
instead of defining its own abstract__hash__
method.In the
PriorityQueue
class, it would be more general to implement__len__
than__nonzero__
(also, the latter is Python 2 only):def __len__(self): return len(self._heap)
The name
TiledMap
suggests that you are going to be looking for paths in a unit-cost grid (as in many 2-dimensional tile-based games). But in this use case, jump point search is a much better algorithm than A*.
-
\$\begingroup\$ Thank you so much for your review. 4. A tile is something that defines the space at the given position (e.g. there's a Stone so it's inaccessible). I could of course pass the position to the mover, but then the mover would have to ask the
Map
what's there. \$\endgroup\$Joschua– Joschua2013年11月15日 14:38:35 +00:00Commented Nov 15, 2013 at 14:38 -
\$\begingroup\$ 2, 3, 5-12 are done. Also, I'm unsure about 7. - see question in previous comment. Regarding 6. I think about putting get_estimate (
""" Returns the admissible heuristic estimate for the given positions. """
) into the Position class - is this a good idea? \$\endgroup\$Joschua– Joschua2013年11月15日 15:18:40 +00:00Commented Nov 15, 2013 at 15:18
Explore related questions
See similar questions with these tags.