I'm trying to solve 15 puzzle using A* algorithm, but something really bad goes on in my get_solution()
function that ruins performance. I guess there is a too much usage of maps in here, but I don't understand why it slows down my program so much. What do you think?
I would be really happy if you could review my coding style as well.
import random
# Class represents playing desk
class Desk(object):
SHUFFLE_NUMBER = 20 # changing to 200 and higher ruins everything
def __init__(self, width, height):
self.matrix =[]
for i in range(height):
row = [x + 1 for x in range(i * width, (i+1) * width)]
self.matrix.append(row)
self.matrix[height - 1][ width - 1] = 0
def height(self):
return len(self.matrix)
def width(self):
return len(self.matrix[0])
def __str__(self):
str_list = []
for r in self.matrix:
for c in r:
str_list.append(str(c) + "\t")
str_list.append("\n")
str_list.pop()
return "".join(str_list)
def __eq__(self, other):
if (self.width() != other.width() or self.height() != other.height()):
return False
for r in range(self.height()):
for c in range(self.width()):
if self.matrix[r][c] != other.matrix[r][c]:
return False;
return True
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.__str__())
def shuffle(self):
for i in range(Desk.SHUFFLE_NUMBER):
self.matrix = self.neighbors()[random.randint(0, len(self.neighbors()) - 1)].matrix
def get_element(self, row, col):
return self.matrix[row][col]
def set_element(self, row, col, value):
self.matrix[row][col] = value
def copy(self):
newDesk = Desk(self.width(), self.height())
for r in range(self.height()):
for c in range(self.width()):
newDesk.set_element(r, c, self.matrix[r][c])
return newDesk
def heuristic_cost(self):
totalSum = 0
for r in range(self.height()):
for c in range(self.width()):
n = self.matrix[r][c] - 1
if (n == -1):
n = self.width() * self.height() - 1
r_solved = n / self.height()
c_solved = n % self.width()
totalSum += abs(r - r_solved)
totalSum += abs(c - c_solved)
return totalSum
def swap(self, r1, c1, r2, c2):
term = self.matrix[r1][c1]
self.matrix[r1][c1] = self.matrix[r2][c2]
self.matrix[r2][c2] = term
def neighbors(self):
neighbors = []
w = self.width()
h = self.height()
for r in range(h):
for c in range(w):
if (self.matrix[r][c] == 0):
if (r != 0):
neighbor = self.copy()
neighbor.swap(r, c, r - 1, c)
neighbors.append(neighbor)
if (r != h - 1):
neighbor = self.copy()
neighbor.swap(r, c, r + 1, c)
neighbors.append(neighbor)
if (c != 0):
neighbor = self.copy()
neighbor.swap(r, c, r, c - 1)
neighbors.append(neighbor)
if (c != w - 1):
neighbor = self.copy()
neighbor.swap(r, c, r, c + 1)
neighbors.append(neighbor)
return neighbors
# Class represents the game
class Puzzle15(object):
def __init__(self, width=4, height=4):
self.desk = Desk(width, height)
self.desk.shuffle()
self.steps = 0
def __str__(self):
return str(self.desk)
def __repr__(self):
return str(self.desk)
def lowest_score_element(self, openset, score):
min_score = 2**30
min_elem = None
for elem in openset:
if (elem in score.keys()):
if (score[elem] < min_score):
min_elem = elem
min_score = score[elem]
return min_elem
def get_solution(self):
start = self.desk.copy()
goal = Desk(self.desk.width(), self.desk.height())
closed_set = []
openset = [start]
came_from = {}
g_score = { start: 0 }
f_score = { start: g_score[start] + start.heuristic_cost()}
while len(openset) != 0:
current = self.lowest_score_element(openset, f_score)
if (current == goal):
return self.reconstruct_path(came_from, current)
openset.remove(current)
closed_set.append(current)
neighbors = current.neighbors()
for neighbor in neighbors:
tentative_g_score = g_score[current] + 1
tentative_f_score = tentative_g_score + neighbor.heuristic_cost()
if neighbor in closed_set and f_score.has_key(neighbor) and tentative_f_score >= f_score[neighbor]:
continue
if neighbor not in openset or (f_score.has_key(neighbor) and tentative_f_score < f_score[neighbor]):
came_from[neighbor] = current
g_score[neighbor] = tentative_g_score
f_score[neighbor] = tentative_f_score
if neighbor not in openset:
openset.append(neighbor)
self.steps += 1
return None
def reconstruct_path(self, came_from, current_node):
if (came_from.has_key(current_node)):
p = self.reconstruct_path(came_from, came_from[current_node])
return p + [current_node]
else:
return [current_node]
if __name__ == '__main__':
puzzle = Puzzle15(3,3)
solution = puzzle.get_solution()
print puzzle.steps
for s in solution:
print s
print
1 Answer 1
1. You can't fix what you can't measure
In order to improve the performance of your code, we need to be able to measure its performance, and that's hard to do, because your Puzzle15
class randomly shuffles the Desk
associated with each instance, so it is not easy to set up and carry out a systematic test.
Let's fix that, by changing Puzzle15.__init__
so that it takes a Desk
instance of our choosing:
def __init__(self, desk):
self.desk = desk
self.steps = 0
Now we can create some test cases:
def puzzle15_test_cases(n, width=4, height=4):
"""Generate 'n' pseudo-random (but repeatable) test cases."""
random.seed(1252481602)
for _ in range(n):
desk = Desk(width, height)
desk.shuffle()
yield desk
TEST_CASES = list(puzzle15_test_cases(100))
And a function that times the solution of all the cases, using the timeit
module:
from timeit import timeit
def puzzle15_time():
"""Return the time taken to solve the puzzles in the TEST_CASES list."""
return timeit('[Puzzle15(desk).get_solution() for desk in TEST_CASES]',
'from __main__ import TEST_CASES, Puzzle15; gc.enable()',
number = 1)
(See here for an explanation of gc.enable()
.) It takes more than three minutes on my machine to solve all 100 test cases:
>>> puzzle15_time()
184.2024281024933
2. Use sets for fast lookup
The first obvious time sink is the open and closed sets. You implement these using lists, but Python's list
does not have an efficient membership test: to implement neighbor in closed_set
Python has to scan all the way along the list, comparing neighbor
with each item in turn until it finds one that matches. By using set
objects instead, we get a constant-time membership test. So change:
closed_set = []
openset = [start]
to:
closed_set = set()
openset = set([start])
and use the set.add()
method instead of list.append()
. This change gives us an immediate 28% speedup:
>>> puzzle15_time()
132.80268001556396
3. Use the power of the dictionary
The next obvious problem is lowest_score_element
. You have a double loop:
for elem in openset:
if (elem in score.keys()):
So for each position in openset
, you construct a fresh list containing the keys of the dictionary score
, and then you look up the position in the list (which, as explained above, might require comparing the position to every item in the list). You could just write:
for elem in openset:
if elem in score:
so that the membership test uses the fast dictionary lookup.
But you don't even need to do this, because Python already has a function min
for finding the smallest element in a collection. So I would implement the method like this:
def lowest_score_element(self, openset, score):
return min(openset, key=score.get)
And that yields a very dramatic improvement:
>>> puzzle15_time()
3.443160057067871
4. Make positions immutable
Instances of the Desk
class represent positions in the 15 puzzle. You need to look up these positions in sets and dictionaries, and that means they need to be hashable. But if you read the documentation for the special __hash__
method, you'll see that it says:
If a class defines mutable objects and implements a
__cmp__()
or__eq__()
method, it should not implement__hash__()
, since hashable collection implementations require that a object’s hash value is immutable (if the object’s hash value changes, it will be in the wrong hash bucket).
Your Desk
objects are currently mutable — they can be changed by the swap
and set_element
and shuffle
methods. This makes them unsuitable for storing in sets or using as dictionary keys. So let's make them immutable instead. And at the same time, make the following improvements:
Use the more understandable name
Position
instead ofDesk
.Write docstrings for the class and its methods.
Represent the matrix as a single tuple instead of a list-of-lists. This means that a cell can be fetched in a single lookup instead of two lookups.
Remove the
get_element
method (just look up cells directly, saving a method call) and remove theset_element
method (it's not needed now that the class is immutable).Represent the blank item as 15 rather than 0, to avoid the special case in
heuristic_cost
.Remember where the blank element is, so that neighbours can be generated without having to search the matrix to find the blank.
Generate the neighbours instead of returning them as a list.
Speed up the
__hash__
by usingtuple.__hash__
instead of carrying out an expensive conversion to a string each time we want to compute the hash.Rename
swap
andshuffle
toswapped
andshuffled
since they now return new objects instead of modifying the old object in place.Pass the number of swaps as a parameter to the
shuffled
method.Avoid calling
self.neighbors
twice in theshuffle
method, by usingrandom.choice
.Use
divmod
instead of separately computing/
and%
.Divide the heuristic cost by 2 (because one swap changes the sum of distances by 2).
That results in the following code:
class Position(object):
"""A position in the 15 puzzle (or a variant thereof)."""
def __init__(self, width, height, matrix=None, blank=None):
assert(width > 1 and height > 1)
self.width = width
self.height = height
self.cells = width * height
if matrix is None:
matrix = tuple(range(self.cells))
blank = self.cells - 1
assert(len(matrix) == self.cells)
assert(0 <= blank < self.cells)
self.matrix = matrix
self.blank = blank
def __repr__(self):
return 'Position({0.width}, {0.height}, {0.matrix})'.format(self)
def __eq__(self, other):
return self.matrix == other.matrix
def __ne__(self, other):
return self.matrix != other.matrix
def __hash__(self):
return hash(self.matrix)
def shuffled(self, swaps=20):
"""Return a new position after making 'swaps' swaps."""
result = self
for _ in range(swaps):
result = random.choice(list(result.neighbors()))
return result
def heuristic_cost(self):
"""Return an admissible estimate of the number of swaps needed to
solve the puzzle from this position.
"""
total = 0
for m, n in enumerate(self.matrix):
mx, my = divmod(m, self.width)
nx, ny = divmod(n, self.width)
total += abs(mx - nx) + abs(my - ny)
return total // 2
def swapped(self, c):
"""Return a new position with cell 'c' swapped with the blank."""
assert(c != self.blank)
i, j = sorted([c, self.blank])
return Position(self.width, self.height,
self.matrix[:i] + (self.matrix[j],)
+ self.matrix[i+1:j] + (self.matrix[i],)
+ self.matrix[j+1:], c)
def neighbors(self):
"""Generate the neighbors to this position, namely the positions
reachable from this position via a single swap.
"""
zy, zx = divmod(self.blank, self.width)
if zx > 0:
yield self.swapped(self.blank - 1)
if zx < self.width - 1:
yield self.swapped(self.blank + 1)
if zy > 0:
yield self.swapped(self.blank - self.width)
if zy < self.height - 1:
yield self.swapped(self.blank + self.width)
(We also have to make a couple of minor changes to Puzzle15
and puzzle15_time
for this to work, but it should be clear what's required.)
This yields another ×ばつ speedup:
>>> puzzle15_time()
0.42876315116882324
5. Find the minimum using a heap
Python's min
function still has to look at every element in the open set, taking O(n) time if there are n elements in the open set. If we kept a copy of the open set in a heap data structure, then we'd be able to find the minimum element in O(log n).
Python's built-in heapq
module provides a way to do this, and moreover, by storing the f-score, g-score and parent position in the heap, we can get rid of the dictionaries in which you currently store these values.
This answer is getting rather long, so I'll just give you the revised code and you can figure out how it works for yourself!
import heapq
class NotFoundError(Exception): pass
def puzzle15_solve(start):
goal = Position(start.width, start.height)
closed_set = set()
# Heap items are lists [F-score, G-score, position, parent data]
start_data = [start.heuristic_cost(), 0, start, None]
# open_set and open_heap always contain the same positions.
# open_set maps each position to the corresponding heap item.
open_set = {start: start_data}
open_heap = [start_data]
while open_heap:
current_data = heapq.heappop(open_heap)
f_current, g_current, current, parent_data = current_data
if current == goal:
def path(data):
while data:
yield data[2]
data = data[3]
return list(path(current_data))[::-1]
del open_set[current]
closed_set.add(current)
for neighbor in current.neighbors():
if neighbor in closed_set:
continue
g_neighbor = g_current + 1
f_neighbor = g_neighbor + neighbor.heuristic_cost()
neighbor_data = [f_neighbor, g_neighbor, neighbor, current_data]
if neighbor not in open_set:
open_set[neighbor] = neighbor_data
heapq.heappush(open_heap, neighbor_data)
else:
old_neighbor_data = open_set[neighbor]
if neighbor_data < old_neighbor_data:
old_neighbor_data[:] = neighbor_data
heapq.heapify(open_heap)
raise NotFoundError("No solution for {}".format(start))
(Note that there's no persistent data here, so there's no need to make this function into a class.)
That yields another ×ばつ speedup or so:
>>> puzzle15_time()
0.2613048553466797
6. A different algorithm
I've managed to make your code about 700 times faster (on this particular set of 20-shuffle test cases), but you'll find that if you shuffle the start position many times, the search often still does not complete in a reasonable amount of time.
This is due to a combination of (i) the size of the search space (10,461,394,944,000 positions); and (ii) the A* algorithm, which keeps searching until it has found the best (shortest) solution.
But if you don't care about getting the best solution, but are willing to accept any solution, then you could change the algorithm to ignore the g-score. Replace:
f_neighbor = g_neighbor + neighbor.heuristic_cost()
with:
f_neighbor = neighbor.heuristic_cost()
and now you can solve arbitrarily shuffled positions in a couple of seconds:
>>> timeit(lambda:puzzle15_solve(Position(4, 4).shuffled(1000)), number=1)
2.0673770904541016
This change turns A* into a form of best-first search.
-
\$\begingroup\$ You have done an amazing job,congratulations. It seems though that the best-first search algorithm needs 14x time to solve the TEST_CASES. Is this expected sometimes, or is it a blunder of mine? \$\endgroup\$NameOfTheRose– NameOfTheRose2017年06月05日 16:06:08 +00:00Commented Jun 5, 2017 at 16:06
-
\$\begingroup\$ @NameOfTheRose: Hard to know what happened without seeing exactly what you did, but I find that the best-first search is more than 140 times faster on a set of 20-shuffle test cases. \$\endgroup\$Gareth Rees– Gareth Rees2017年06月05日 18:21:01 +00:00Commented Jun 5, 2017 at 18:21
-
\$\begingroup\$ There is one particular puzzle that takes too long:
Position(4,4,(1, 2, 3, 7, 4, 0, 5, 6, 15, 9, 10, 11, 8, 12, 13, 14),8)
. But in general the best-first-search is roughly only twice as fast. Of course it solves puzzles (like the 1000 shuffle above in 1.4s) that A* search fails to solve. The only change I did was to change thef_neighbor
calculation as shown above. Thank you. \$\endgroup\$NameOfTheRose– NameOfTheRose2017年06月06日 12:16:30 +00:00Commented Jun 6, 2017 at 12:16 -
\$\begingroup\$ For python3 compatibility, one has to add
__lt__
method to thePosition class
def __lt__(self, other): return self.heuristic_cost() < other.heuristic_cost()
\$\endgroup\$NameOfTheRose– NameOfTheRose2017年07月07日 12:48:28 +00:00Commented Jul 7, 2017 at 12:48
Explore related questions
See similar questions with these tags.