The puzzle goes like this: in a rectangular 2D grid there are empty spaces (.), exactly one starting point (S, s) and obstacles (denoted below by X's). The objective of the puzzle is to find a path starting at the starting point and going through each empty space exactly once (a Hamiltonian path). You can't, of course, cross the obstacles. You can move horizontally and vertically. A typical puzzle would look like this:
......
.SX...
...X..
.....X
......
XX....
......
And its solution:
11 12 13 14 17 18
10 1 X 15 16 19
9 2 3 X 21 20
8 5 4 23 22 X
7 6 25 24 29 30
X X 26 27 28 31
37 36 35 34 33 32
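To make the rules concrete, here is a minimal sketch of a checker that verifies a numbered solution against a puzzle grid. The function name is_valid_solution and the whitespace-separated solution format are assumptions made for illustration; they are not part of the solver below.

```python
PUZZLE = [
    "......",
    ".SX...",
    "...X..",
    ".....X",
    "......",
    "XX....",
    "......",
]

SOLUTION = [
    "11 12 13 14 17 18",
    "10 1 X 15 16 19",
    "9 2 3 X 21 20",
    "8 5 4 23 22 X",
    "7 6 25 24 29 30",
    "X X 26 27 28 31",
    "37 36 35 34 33 32",
]


def is_valid_solution(puzzle_rows, solution_rows):
    """Return True if solution_rows numbers a Hamiltonian path (hypothetical helper)."""
    if len(puzzle_rows) != len(solution_rows):
        return False
    position = {}  # step number -> (row, column)
    for r, (puzzle_row, solution_row) in enumerate(zip(puzzle_rows, solution_rows)):
        cells = solution_row.split()
        if len(cells) != len(puzzle_row):
            return False
        for c, (symbol, cell) in enumerate(zip(puzzle_row, cells)):
            # Obstacles must stay obstacles, and only obstacles may be 'X'.
            if (symbol == 'X') != (cell == 'X'):
                return False
            if cell != 'X':
                if not cell.isdigit():
                    return False
                position[int(cell)] = (r, c)
    # Every non-obstacle cell must carry a distinct number from 1 to n.
    n = sum(symbol != 'X' for row in puzzle_rows for symbol in row)
    if n == 0 or sorted(position) != list(range(1, n + 1)):
        return False
    # Step 1 must sit on the starting point.
    start_r, start_c = position[1]
    if puzzle_rows[start_r][start_c] not in 'Ss':
        return False
    # Consecutive steps must be horizontal or vertical neighbours.
    return all(
        abs(position[i][0] - position[i + 1][0])
        + abs(position[i][1] - position[i + 1][1]) == 1
        for i in range(1, n)
    )


print(is_valid_solution(PUZZLE, SOLUTION))
```

A checker like this is also handy for confirming that an optimized solver still produces correct paths.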
I wrote a solver in Python 3 (I later learned that this algorithm is actually a simple DFS). It solves the puzzle above in ~0.9 s, which is very good, but I was wondering if I can perhaps optimize it somehow.
import time

EMPTY_SPACE_SYMBOLS = ['.']
STARTING_POINT_SYMBOLS = ['S', 's']
OBSTACLE_SYMBOL = 'X'
PUZZLE_PATH = "grid.txt"
DIRS = [(-1, 0), (1, 0), (0, 1), (0, -1)]

start_time = time.time()
grid = open(PUZZLE_PATH).read().splitlines()
H = len(grid)
W = len(grid[0])
assert all(len(row) == W for row in grid), "Grid not rectangular"

def print_solution(coords):
    result_grid = [[OBSTACLE_SYMBOL for _ in range(W)] for _ in range(H)]
    for i, (r, c) in enumerate(coords, start=1):
        result_grid[r][c] = i
    str_grid = [[str(item).ljust(3) for item in row] for row in result_grid]
    print('\n'.join(' '.join(row) for row in str_grid))

def extend(path, legal_coords):
    res = []
    lx, ly = path[-1]
    for dx, dy in DIRS:
        new_coord = (lx + dx, ly + dy)
        if new_coord in legal_coords and new_coord not in path:
            res.append(path + [new_coord])
    return res

start_pos = None
legal = set()
for r, row in enumerate(grid):
    for c, item in enumerate(row):
        if item in STARTING_POINT_SYMBOLS:
            assert start_pos is None, "Multiple starting points"
            start_pos = (r, c)
        elif item in EMPTY_SPACE_SYMBOLS:
            legal.add((r, c))
assert start_pos is not None, "No starting point"

TARGET_PATH_LEN = len(legal) + 1
paths = [[start_pos]]
found = False
number_of_solutions = 0
while paths:
    cur_path = paths.pop()
    if len(cur_path) == TARGET_PATH_LEN:
        number_of_solutions += 1
        if not found:
            print_solution(cur_path)
            print("Solution found in {} s".format(time.time() - start_time))
            found = True
    paths += extend(cur_path, legal)

print('Total number of solutions found: {} (took: {} s)'.format(number_of_solutions, time.time() - start_time))
- I was thinking of using this code to test with. I wonder why when I make the grid 10x10 it doesn't calculate... – Joseph Langdon, Mar 4, 2022 at 6:24
2 Answers
1. Review
Code is easier to test (and to measure the performance of) if it's organized into functions or classes. In this case you have persistent data (W, H, grid, legal, start_pos) that's shared between the various parts of the code, so a class would be the best way to go here.
It would be simpler to write:
EMPTY_SPACE_SYMBOLS = '.'
STARTING_POINT_SYMBOLS = 'Ss'
since Python can iterate over the characters in a string just as easily as over the items in a list.
Assertions should be reserved for detecting programming errors, and not used for reporting problems with the input. That's because assertions can be disabled via the -O option to the Python interpreter, or the PYTHONOPTIMIZE environment variable. For the errors in this code, you should raise exceptions.
It's more general to have a method format_solution that returns a string, rather than a function print_solution that prints it. That way the caller has the option of deciding what to do with it.
Instead of:
[OBSTACLE_SYMBOL for _ in range(self.w)]
write:
[OBSTACLE_SYMBOL] * self.w
Instead of building str_grid from result_grid and then formatting the former, just format result_grid directly:
'\n'.join(' '.join(str(item).ljust(3) for item in row) for row in result_grid)
The number "3" only works for small enough puzzles. To make the formatting work in the general case, it should be one more than the maximum number of digits in any number in the path, that is:
len(str(len(path) + 1)) + 1
extend is only called from one place in the code, so it would make sense to inline it at that point.
Coordinates are r and c in some places in the code, and x and y in other places. It would be better to be consistent.
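As a small illustration of the column-width point, here is a sketch of how the width could be derived from the path length. The helper names cell_width and format_row are hypothetical. The labels run from 1 to len(path), so this sketch uses the digit count of len(path) itself; the expression given above is never narrower than this, so either keeps the columns aligned.

```python
def cell_width(path_length):
    """Digits in the largest label (path_length) plus one separating space."""
    return len(str(path_length)) + 1


def format_row(labels, width):
    """Left-justify each label to the given width and drop trailing spaces."""
    return ''.join(str(label).ljust(width) for label in labels).rstrip()


print(cell_width(37))                               # puzzle from the post
print(format_row([99, 100, 'X'], cell_width(100)))  # three-digit labels
```

With a fixed width of 3, labels of 100 and above would run into their neighbours, which is the failure the review item describes.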
2. Revised code
import time

EMPTY_SPACE_SYMBOLS = '.'
STARTING_POINT_SYMBOLS = 'Ss'
OBSTACLE_SYMBOL = 'X'
DIRS = [(-1, 0), (1, 0), (0, 1), (0, -1)]

class HamiltonSolver:
    """Solver for a Hamilton Path problem."""

    def __init__(self, grid):
        """Initialize the HamiltonSolver instance from a grid, which must be a
        list of strings, one for each row of the grid.
        """
        self.grid = grid
        self.h = h = len(grid)
        self.w = w = len(grid[0])
        if any(len(row) != w for row in grid):
            raise ValueError("Grid is not rectangular")
        self.start = None
        self.legal = set()
        for r, row in enumerate(grid):
            for c, item in enumerate(row):
                if item in STARTING_POINT_SYMBOLS:
                    if self.start is not None:
                        raise ValueError("Multiple starting points")
                    self.start = (r, c)
                elif item in EMPTY_SPACE_SYMBOLS:
                    self.legal.add((r, c))
        if self.start is None:
            raise ValueError("No starting point")

    def format_solution(self, path):
        """Format a path as a string."""
        grid = [[OBSTACLE_SYMBOL] * self.w for _ in range(self.h)]
        for i, (r, c) in enumerate(path, start=1):
            grid[r][c] = i
        w = len(str(len(path) + 1)) + 1
        return '\n'.join(''.join(str(item).ljust(w) for item in row)
                         for row in grid)

    def solve(self):
        """Generate solutions as lists of coordinates."""
        target_path_len = len(self.legal) + 1
        paths = [[self.start]]
        while paths:
            path = paths.pop()
            if len(path) == target_path_len:
                yield path
            r, c = path[-1]
            for dr, dc in DIRS:
                new_coord = r + dr, c + dc
                if new_coord in self.legal and new_coord not in path:
                    paths.append(path + [new_coord])

PUZZLE_GRID = '''
......
.SX...
...X..
.....X
......
XX....
......
'''.split()

def main():
    start_time = time.time()
    n_solutions = 0
    puzzle = HamiltonSolver(PUZZLE_GRID)
    for solution in puzzle.solve():
        if n_solutions == 0:
            print(puzzle.format_solution(solution))
            print("First solution found in {} s"
                  .format(time.time() - start_time))
        n_solutions += 1
    print("{} solution{} found in {} s"
          .format(n_solutions, '' if n_solutions == 1 else 's',
                  time.time() - start_time))

if __name__ == '__main__':
    main()
3. Speeding it up
Each time extend is called, the list of paths under consideration is extended by up to four new paths, each of which is a copy of the current path plus one new grid coordinate. As the search progresses, this results in many paths being created and stored (up to 20 paths in the example from the post).
But in a depth-first search it should only be necessary to remember a single path at a time, together with the current position in the iteration over the four directions. (This is the "stack of iterators" pattern.)
The code checks to see if the new coordinate is available by calling new_coord in self.legal and new_coord not in path. But path is a list, and testing for membership in a list is not efficient (Python has to compare the new coordinate with every item in the list). For efficient membership testing you need a set.
In this case the code already has a set of legal coordinates. So we can remove a coordinate from this set when we add it to the path, and add it back to the set when we backtrack. This means we only need one membership test, and it also means that the path is complete when the set of legal coordinates is empty.
We can speed up the inner loop by caching method lookups like path.append in local variables.
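The set bookkeeping described above can be seen in isolation in the following sketch. It is a recursive variant written only for illustration, not the iterative method this answer goes on to use, and the function name count_paths is hypothetical. Being recursive, it would run into Python's recursion limit on large grids.

```python
DIRS = [(-1, 0), (1, 0), (0, 1), (0, -1)]


def count_paths(position, legal):
    """Count Hamiltonian paths from position through every cell in legal."""
    if not legal:
        # No unvisited cells remain, so the current path is complete.
        return 1
    r, c = position
    total = 0
    for dr, dc in DIRS:
        new_coord = (r + dr, c + dc)
        if new_coord in legal:           # one set lookup replaces two tests
            legal.remove(new_coord)      # entering the cell marks it as used
            total += count_paths(new_coord, legal)
            legal.add(new_coord)         # backtracking makes it available again
    return total


# 2x2 grid with the start in the top-left corner.
unvisited = {(0, 1), (1, 0), (1, 1)}
print(count_paths((0, 0), unvisited))
```

Because every removal is paired with an add on the way back, the set holds its original contents again once the search finishes.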
The revised solve method looks like this:
def solve(self):
    """Generate solutions as lists of coordinates."""
    path = [self.start]
    dirs = [iter(DIRS)]

    # Cache attribute lookups in local variables
    path_append = path.append
    path_pop = path.pop
    legal = self.legal
    legal_add = legal.add
    legal_remove = legal.remove
    dirs_append = dirs.append
    dirs_pop = dirs.pop

    while path:
        r, c = path[-1]
        for dr, dc in dirs[-1]:
            new_coord = r + dr, c + dc
            if new_coord in legal:
                path_append(new_coord)
                legal_remove(new_coord)
                dirs_append(iter(DIRS))
                if not legal:
                    yield path
                break
        else:
            legal_add(path_pop())
            dirs_pop()
On the example from the post this is more than 2½ times as fast as the original code.
Update: You asked in comments how this works. Well, the list path contains the current path reached in the search. The list dirs is the same length as path, and each element is an iterator over the list DIRS. Each iterator remembers where we got to in the list of directions at that point in the search, so that when we backtrack we can pop the dirs list and then continue iterating from where we left off.
The control flow is slightly tricky. When the path is extended it reaches the break statement, which breaks out of the for loop (bypassing the else) and so goes back to the top of the while loop, picking up the new coordinates that we just added to the path. But when no further progress can be made, the for loop runs out of directions and enters the else clause. This causes path and dirs to be popped and then the while loop picks up from the previous point reached in the search.
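The same control flow can be tried out on something smaller than the puzzle. The sketch below applies the stack-of-iterators pattern to a tiny directed graph; the function name paths_to and the example graph are made up for illustration. It yields a tuple copy of the path rather than the live list, because the list keeps changing as the search continues. The solve method above yields the list itself, so a caller that stores solutions instead of printing them straight away would probably want to copy each one.

```python
def paths_to(graph, start, goal):
    """Yield every simple path from start to goal as a tuple."""
    path = [start]
    visited = {start}
    iterators = [iter(graph[start])]  # one iterator per node on the path
    while path:
        for neighbour in iterators[-1]:
            if neighbour not in visited:
                path.append(neighbour)
                visited.add(neighbour)
                iterators.append(iter(graph[neighbour]))
                if neighbour == goal:
                    yield tuple(path)  # snapshot, since path is mutated later
                break  # descend: the while loop resumes at the new node
        else:
            # Iterator exhausted: backtrack to the previous node, whose
            # iterator remembers which neighbours were already tried.
            visited.remove(path.pop())
            iterators.pop()


graph = {'a': ['b', 'c'], 'b': ['d'], 'c': ['d'], 'd': []}
for found in paths_to(graph, 'a', 'd'):
    print(found)
```

Adding a print of path at the top of the while loop shows the extend and backtrack steps in order.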
I hope that's clear. If you're still not quite sure, step through the code in the Python debugger so that you can see exactly what it's doing.
- This revised solve method is brilliant, but could you maybe explain how the for-else loop works in this case? Thank you for your help. – shooqie, Sep 27, 2016 at 8:32
- @shooqie: See updated answer. – Gareth Rees, Sep 27, 2016 at 9:43
I suggest starting from the issues of organization and naming in order to make optimization simpler:
Organization
The organization of your program could have been cleaner: function definitions and top-level blocks of code alternate throughout.
The usual layout is to define all the functions first and then have a main function call them one by one, putting everything together.
This organization makes reading the code easier by compartmentalizing it into almost independent blocks.
Some practical examples:
start_pos = None
legal = set()
for r, row in enumerate(grid):
    for c, item in enumerate(row):
        if item in STARTING_POINT_SYMBOLS:
            assert start_pos is None, "Multiple starting points"
            start_pos = (r, c)
        elif item in EMPTY_SPACE_SYMBOLS:
            legal.add((r, c))
This piece of code clearly finds the start_pos and the legal (coordinates) so it can be a function:
def find_start_and_legals(grid):
    start_pos = None
    legal = set()
    for r, row in enumerate(grid):
        for c, item in enumerate(row):
            if item in STARTING_POINT_SYMBOLS:
                assert start_pos is None, "Multiple starting points"
                start_pos = (r, c)
            elif item in EMPTY_SPACE_SYMBOLS:
                legal.add((r, c))
    return start_pos, legal
This other block of code finds the Hamiltonian paths:
while paths:
    cur_path = paths.pop()
    if len(cur_path) == TARGET_PATH_LEN:
        number_of_solutions += 1
        if not found:
            print_solution(cur_path)
            print("Solution found in {} s".format(time.time() - start_time))
            found = True
    paths += extend(cur_path, legal)
So it can become a function too:
def find_hamilton_paths(grid):
    start_pos, legal = find_start_and_legals(grid)
    TARGET_PATH_LEN = len(legal) + 1
    paths = [[start_pos]]
    while paths:
        cur_path = paths.pop()
        if len(cur_path) == TARGET_PATH_LEN:
            yield cur_path
        paths += extend(cur_path, legal)
I also decided to remove the print and time statements to better highlight the algorithm.
Timing your program with time.time is not appropriate for larger projects and does not offer benefits in smaller ones either.
Use time python3 hamilton.py to get the overall runtime and python -m cProfile hamilton.py to get detailed function-by-function profiling.
Separation of code into functions also helps because the profiler clearly shows that some functions take so little time that any time spent optimizing them would be wasted:
2 0.000 0.000 0.000 0.000 hamilton.py:16(find_start_and_legals)
1 0.000 0.000 0.000 0.000 hamilton.py:9(print_solution)
Profiling works better with smaller functions.
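If you would rather profile a single function from inside Python than the whole script from the command line, the standard library's cProfile and pstats modules can be driven directly. This is a minimal sketch: slow_membership is a stand-in workload invented for the example and profile_call is a hypothetical helper, not part of the solver.

```python
import cProfile
import io
import pstats


def slow_membership(n):
    """Stand-in workload: repeated membership tests against a list."""
    items = list(range(n))
    return sum(1 for i in range(n) if i in items)


def profile_call(func, *args):
    """Run func(*args) under cProfile and return (result, report text)."""
    profiler = cProfile.Profile()
    profiler.enable()
    result = func(*args)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Show the five entries with the highest cumulative time.
    stats.sort_stats('cumulative').print_stats(5)
    return result, stream.getvalue()


result, report = profile_call(slow_membership, 500)
print(report)
```

The report has the same columns as the command-line output shown above, one line per function.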
I include my modularized version of the code for your benefit and that of future optimizers:
EMPTY_SPACE_SYMBOLS = ['.']
STARTING_POINT_SYMBOLS = ['S', 's']
OBSTACLE_SYMBOL = 'X'
PUZZLE_PATH = "grid.txt"
DIRS = [(-1, 0), (1, 0), (0, 1), (0, -1)]

def print_solution(coords, WIDTH, HEIGHT):
    result_grid = [[OBSTACLE_SYMBOL for _ in range(WIDTH)] for _ in range(HEIGHT)]
    for i, (r, c) in enumerate(coords, start=1):
        result_grid[r][c] = i
    str_grid = [[str(item).ljust(3) for item in row] for row in result_grid]
    print('\n'.join(' '.join(row) for row in str_grid))

def find_start_and_legals(grid):
    start_pos = None
    legal = set()
    for r, row in enumerate(grid):
        for c, item in enumerate(row):
            if item in STARTING_POINT_SYMBOLS:
                assert start_pos is None, "Multiple starting points"
                start_pos = (r, c)
            elif item in EMPTY_SPACE_SYMBOLS:
                legal.add((r, c))
    return start_pos, legal

def extend(path, legal_coords):
    res = []
    lx, ly = path[-1]
    for dx, dy in DIRS:
        new_coord = (lx + dx, ly + dy)
        if new_coord in legal_coords and new_coord not in path:
            res.append(path + [new_coord])
    return res

def find_hamilton_paths(grid):
    start_pos, legal = find_start_and_legals(grid)
    TARGET_PATH_LEN = len(legal) + 1
    paths = [[start_pos]]
    while paths:
        cur_path = paths.pop()
        if len(cur_path) == TARGET_PATH_LEN:
            yield cur_path
        paths += extend(cur_path, legal)

def read_grid(filename):
    with open(filename) as f:
        grid = f.read().splitlines()
    HEIGHT, WIDTH = len(grid), len(grid[0])
    assert all(len(row) == WIDTH for row in grid), "Grid not rectangular"
    return grid, HEIGHT, WIDTH

def main():
    grid, HEIGHT, WIDTH = read_grid(PUZZLE_PATH)
    start_pos, legal = find_start_and_legals(grid)
    for solution in find_hamilton_paths(grid):
        print_solution(solution, WIDTH, HEIGHT)

if __name__ == "__main__":
    main()
Naming
To optimize your program people must understand it very well, so naming becomes even more important because it helps the reader in understanding difficult algorithms:
r -> row
c -> column
start_pos -> start_position
legal -> legal_coordinates
res -> result or new_paths
lx, ly -> ??? I did not understand this.
dx, dy -> delta_x, delta_y
H, W -> HEIGHT, WIDTH
Time wasted deducing full names from abbreviations is time that could be spent understanding the algorithm and optimizing it.
Code is written once and read many times, so please optimize for read time, not write time.
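Applied to extend, the renames might look like the sketch below. The names last_row and last_column are an assumption about what lx and ly stand for, based on their being unpacked from path[-1]; the deltas are named after rows and columns rather than x and y so that they match r and c elsewhere, and DIRECTIONS is likewise an illustrative rename of DIRS.

```python
DIRECTIONS = [(-1, 0), (1, 0), (0, 1), (0, -1)]


def extend(path, legal_coordinates):
    """Return every one-step extension of path onto an unvisited legal cell."""
    new_paths = []
    last_row, last_column = path[-1]  # assumed meaning of lx, ly
    for delta_row, delta_column in DIRECTIONS:
        new_coordinate = (last_row + delta_row, last_column + delta_column)
        if new_coordinate in legal_coordinates and new_coordinate not in path:
            new_paths.append(path + [new_coordinate])
    return new_paths


print(extend([(0, 0)], {(0, 1), (1, 0)}))
```

The behaviour is unchanged from the original function; only the names differ.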
- I for one prefer start_pos, and would use col instead of column, because I don't like to read big words. – Dair, Sep 27, 2016 at 1:19