A Sudoku solver that works recursively. I'd appreciate your comments about coding style, structure and how to improve it. Thank you very much for your time.
Code structure
The solver works by accepting a string of 81 digits as the Sudoku puzzle input; zeros are taken as empty cells. It parses the string into a 9x9 NumPy array.
The `get_candidates` function creates lists of possible digits for each cell, following Sudoku's rules (no repeated 1-9 digit along rows, columns and 3x3 sub-grids).
The main solver function is `solve`. First, it discards wrong candidates with the `filter_candidates` function. "Wrong candidates" are those that, when filled into an empty cell, leave some other cell on the grid with no candidates at all.
After filtering candidates, `fill_singles` is called to fill empty cells that have only one remaining candidate. If this process leads to a completely filled Sudoku grid, it is returned as the solution. There is a clause to return `None`, which is used to backtrack changes made by the `make_guess` function. `make_guess` fills the next empty cell having the fewest candidates with one of them, a "guess" value. It then recursively calls `solve` to either find a solution or reach an unsolvable grid (in which case `solve` returns `None` and the last guess is reverted).
from copy import deepcopy
import numpy as np
def create_grid(puzzle_str: str) -> np.ndarray:
    """Create a 9x9 Sudoku grid from a string of digits"""
    # Deleting whitespaces and newlines (\n)
    lines = puzzle_str.replace(' ', '').replace('\n', '')
    digits = list(map(int, lines))
    # Turning it into a 9x9 numpy array
    grid = np.array(digits).reshape(9, 9)
    return grid
def get_subgrids(grid: np.ndarray) -> np.ndarray:
    """Divide the input grid into 9 3x3 sub-grids"""
    subgrids = []
    for box_i in range(3):
        for box_j in range(3):
            subgrid = []
            for i in range(3):
                for j in range(3):
                    subgrid.append(grid[3*box_i + i][3*box_j + j])
            subgrids.append(subgrid)
    return np.array(subgrids)
def get_candidates(grid: np.ndarray) -> list:
    """Get a list of candidates to fill empty cells of the input grid"""
    def subgrid_index(i, j):
        return (i // 3) * 3 + j // 3

    subgrids = get_subgrids(grid)
    grid_candidates = []
    for i in range(9):
        row_candidates = []
        for j in range(9):
            # Row, column and subgrid digits
            row = set(grid[i])
            col = set(grid[:, j])
            sub = set(subgrids[subgrid_index(i, j)])
            common = row | col | sub
            candidates = set(range(10)) - common
            # If the cell is filled, take its value as the only candidate
            if not grid[i][j]:
                row_candidates.append(list(candidates))
            else:
                row_candidates.append([grid[i][j]])
        grid_candidates.append(row_candidates)
    return grid_candidates
def is_valid_grid(grid: np.ndarray) -> bool:
    """Verify the input grid has a possible solution"""
    candidates = get_candidates(grid)
    for i in range(9):
        for j in range(9):
            if len(candidates[i][j]) == 0:
                return False
    return True
def is_solution(grid: np.ndarray) -> bool:
    """Verify if the input grid is a solution"""
    if np.all(np.sum(grid, axis=1) == 45) and \
       np.all(np.sum(grid, axis=0) == 45) and \
       np.all(np.sum(get_subgrids(grid), axis=1) == 45):
        return True
    return False
def filter_candidates(grid: np.ndarray) -> list:
    """Filter input grid's list of candidates"""
    test_grid = grid.copy()
    candidates = get_candidates(grid)
    filtered_candidates = deepcopy(candidates)
    for i in range(9):
        for j in range(9):
            # Check for empty cells
            if grid[i][j] == 0:
                for candidate in candidates[i][j]:
                    # Use test candidate
                    test_grid[i][j] = candidate
                    # Remove candidate if it produces an invalid grid
                    if not is_valid_grid(fill_singles(test_grid)):
                        filtered_candidates[i][j].remove(candidate)
                    # Revert changes
                    test_grid[i][j] = 0
    return filtered_candidates
def merge(candidates_1: list, candidates_2: list) -> list:
    """Take shortest candidate list from inputs for each cell"""
    candidates_min = []
    for i in range(9):
        row = []
        for j in range(9):
            if len(candidates_1[i][j]) < len(candidates_2[i][j]):
                row.append(candidates_1[i][j][:])
            else:
                row.append(candidates_2[i][j][:])
        candidates_min.append(row)
    return candidates_min
def fill_singles(grid: np.ndarray, candidates=None) -> np.ndarray:
    """Fill input grid's cells with single candidates"""
    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    any_fill = True
    while any_fill:
        any_fill = False
        for i in range(9):
            for j in range(9):
                if len(candidates[i][j]) == 1 and grid[i][j] == 0:
                    grid[i][j] = candidates[i][j][0]
                    candidates = merge(get_candidates(grid), candidates)
                    any_fill = True
    return grid
def make_guess(grid: np.ndarray, candidates=None) -> np.ndarray:
    """Fill next empty cell with least candidates with first candidate"""
    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    # Getting the shortest number of candidates > 1:
    min_len = sorted(list(set(map(
        len, np.array(candidates).reshape(1, 81)[0]))))[1]
    for i in range(9):
        for j in range(9):
            if len(candidates[i][j]) == min_len:
                for guess in candidates[i][j]:
                    grid[i][j] = guess
                    solution = solve(grid)
                    if solution is not None:
                        return solution
                    # Discarding a wrong guess
                    grid[i][j] = 0
def solve(grid: np.ndarray) -> np.ndarray:
    """Recursively find a solution filtering candidates and guessing values"""
    candidates = filter_candidates(grid)
    grid = fill_singles(grid, candidates)
    if is_solution(grid):
        return grid
    if not is_valid_grid(grid):
        return None
    return make_guess(grid, candidates)
# # Example usage
# puzzle = """100920000
# 524010000
# 000000070
# 050008102
# 000000000
# 402700090
# 060000000
# 000030945
# 000071006"""
# grid = create_grid(puzzle)
# solve(grid)
2 Answers
I was able to improve the performance of the program by about 900% without understanding or changing much of the algorithm in about an hour. Here's what I did:
First of all, you need a benchmark. It's very simple: just time your program.

import time

start = time.time()
solve(grid)
print(time.time() - start)
On my computer, it took about 4.5 seconds. This is our baseline.
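As a side note, `time.time()` follows the wall clock and can jump with clock adjustments; `time.perf_counter()` is the monotonic timer intended for benchmarking. A small best-of-N helper might look like this (a sketch; the `benchmark` helper and the stand-in workload are my own, not from the post):

```python
import time

def benchmark(fn, *args, repeat=5):
    """Return the best wall-clock time in seconds over `repeat` runs."""
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

# Stand-in workload; for the solver you would call e.g. benchmark(solve, grid)
elapsed = benchmark(sum, range(10**5))
```

Taking the best of several runs filters out interference from other processes, which matters once the run time drops below a second.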
The next thing is to profile. The tool I chose is VizTracer, which is developed by myself :) https://github.com/gaogaotiantian/viztracer
VizTracer generates an HTML report (or JSON that can be loaded by chrome://tracing) of the timeline of your code execution. It looked like this for your original version:
As you can tell, there are a lot of calls there. The structure is not complicated: a lot of `fill_singles` calls, and we need to zoom in to check what's inside them. It's very clear that `get_candidates` is the function that takes most of the time in `fill_singles`, occupying most of the timeline. So that's the function to look at first.
def get_candidates(grid: np.ndarray) -> list:
    """Get a list of candidates to fill empty cells of the input grid"""
    def subgrid_index(i, j):
        return (i // 3) * 3 + j // 3

    subgrids = get_subgrids(grid)
    grid_candidates = []
    for i in range(9):
        row_candidates = []
        for j in range(9):
            # Row, column and subgrid digits
            row = set(grid[i])
            col = set(grid[:, j])
            sub = set(subgrids[subgrid_index(i, j)])
            common = row | col | sub
            candidates = set(range(10)) - common
            # If the cell is filled, take its value as the only candidate
            if not grid[i][j]:
                row_candidates.append(list(candidates))
            else:
                row_candidates.append([grid[i][j]])
        grid_candidates.append(row_candidates)
    return grid_candidates
The thing that caught my eye first was the end of your nested for loop. You check whether `grid[i][j]` is filled; if it is, then its value is the only candidate. However, in that case the `candidates` set you computed so laboriously in the loop body is never used. So the first thing I did was move the check to the beginning of the loop body:
for i in range(9):
    row_candidates = []
    for j in range(9):
        if grid[i][j]:
            row_candidates.append([grid[i][j]])
            continue
        # Row, column and subgrid digits
        row = set(grid[i])
        col = set(grid[:, j])
        sub = set(subgrids[subgrid_index(i, j)])
        common = row | col | sub
        candidates = set(range(10)) - common
        row_candidates.append(list(candidates))
This optimization alone cut the running time in half; we are at about 2.3s now.
Then I noticed that in your nested for loop you are doing a lot of redundant set operations. Each row/column/subgrid set only needs to be computed 9 times, but you are computing them 81 times each, which is pretty bad. So I moved the computation out of the loop:
def get_candidates(grid: np.ndarray) -> list:
    """Get a list of candidates to fill empty cells of the input grid"""
    def subgrid_index(i, j):
        return (i // 3) * 3 + j // 3

    subgrids = get_subgrids(grid)
    grid_candidates = []
    row_sets = [set(grid[i]) for i in range(9)]
    col_sets = [set(grid[:, j]) for j in range(9)]
    subgrid_sets = [set(subgrids[i]) for i in range(9)]
    total_sets = set(range(10))
    for i in range(9):
        row_candidates = []
        for j in range(9):
            # If the cell is filled, take its value as the only candidate
            if grid[i][j]:
                row_candidates.append([grid[i][j]])
                continue
            # Row, column and subgrid digits
            row = row_sets[i]
            col = col_sets[j]
            sub = subgrid_sets[subgrid_index(i, j)]
            common = row | col | sub
            candidates = total_sets - common
            row_candidates.append(list(candidates))
        grid_candidates.append(row_candidates)
    return grid_candidates
This cut the running time to about 1.5s. Notice that I still had not tried to understand your algorithm. The only thing I did was use VizTracer to find the function that needed to be optimized and apply same-logic transforms. I improved performance by about 300% in roughly 15 minutes.
At this point, the overhead of VizTracer on WSL was significant, so I turned off C function tracing. Only Python functions were traced, and the overhead was about 10%.
Now that `get_candidates` was improved (although it can be done better), we need to look at the bigger picture. What I could observe from VizTracer's result was that `fill_singles` calls `get_candidates` very frequently; there are just too many calls. (This is something that's hard to notice with cProfile.) So the next step was to figure out whether we can make `fill_singles` call `get_candidates` less often. That requires some level of algorithm understanding:
while any_fill:
    any_fill = False
    for i in range(9):
        for j in range(9):
            if len(candidates[i][j]) == 1 and grid[i][j] == 0:
                grid[i][j] = candidates[i][j][0]
                candidates = merge(get_candidates(grid), candidates)
                any_fill = True
It looks like you tried to fill in one blank that has a single candidate, recalculate the candidates of the whole grid, then look for the next blank with a single candidate. This is a valid method, but it causes too many calls to `get_candidates`. If you think about it, when we fill in a blank with a number n, the other blanks whose single candidate is not n are unaffected. So during one pass over the grid we can fill in several blanks, as long as we do not fill in the same number twice in that pass. This way we call `get_candidates`, which is a huge time consumer, less often. I used a set to do this:
filled_number = set()
for i in range(9):
    for j in range(9):
        if (len(candidates[i][j]) == 1 and grid[i][j] == 0
                and candidates[i][j][0] not in filled_number):
            grid[i][j] = candidates[i][j][0]
            filled_number.add(candidates[i][j][0])
            any_fill = True
candidates = merge(get_candidates(grid), candidates)
This brought the running time to 0.9s.
Then I looked at the VizTracer report again and realized that `fill_singles` is almost always called by `filter_candidates`, and the only thing `filter_candidates` is interested in is whether `fill_singles` returns a valid grid. This is information we may know early, as soon as `fill_singles` finds a cell with no candidates. If we return early, we don't need to call `get_candidates` so many times. So I changed the code structure a little bit and made `fill_singles` return `None` when it can't produce a valid grid.
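The post doesn't show the modified code, but the change can be sketched roughly like this. This is a self-contained approximation, not the author's exact diff: it uses a simplified `get_candidates` without the `merge` bookkeeping, and bails out with `None` the moment an empty cell has no candidates left:

```python
import numpy as np

def get_candidates(grid):
    """Candidate lists per cell (simplified version of the post's function)."""
    cands = []
    for i in range(9):
        row = []
        for j in range(9):
            if grid[i][j]:
                row.append([grid[i][j]])
                continue
            box = grid[3*(i//3):3*(i//3)+3, 3*(j//3):3*(j//3)+3]
            used = set(grid[i]) | set(grid[:, j]) | set(box.flatten())
            row.append([d for d in range(1, 10) if d not in used])
        cands.append(row)
    return cands

def fill_singles(grid, candidates=None):
    """Fill naked singles; return None as soon as any cell is a dead end."""
    grid = grid.copy()
    if candidates is None:
        candidates = get_candidates(grid)
    any_fill = True
    while any_fill:
        any_fill = False
        for i in range(9):
            for j in range(9):
                if grid[i][j] == 0 and not candidates[i][j]:
                    return None  # early exit: contradiction found
                if grid[i][j] == 0 and len(candidates[i][j]) == 1:
                    grid[i][j] = candidates[i][j][0]
                    candidates = get_candidates(grid)
                    any_fill = True
    return grid
```

A caller like `filter_candidates` then treats a `None` return as an invalid grid directly, instead of running `is_valid_grid` over a fully computed candidate grid.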
Finally, I got the running time down to 0.5s, which is 900% faster than the original version.
It was actually a fun adventure, because I was testing my project VizTracer and trying to figure out whether it helps locate the time-consuming parts. It worked well :)
- Comment by fabrizzio_gz (Aug 25, 2020): That's amazing! Your modifications are quite simple but they vastly improve performance. I'll try performing similar analyses in the future. Congratulations on VizTracer, great tool! I'll use it in the future.
- Comment by minker (Aug 25, 2020): Thanks! I was hoping that VizTracer could help people with problems like this. My boss once told me: never optimize without profiling. A good profiler is indeed very helpful for improving performance :)
Numpyification
`get_subgrids` essentially rearranges a numpy array while using a minimum of numpy. It could be done with numpy itself, for example:
def get_subgrids(grid: np.ndarray) -> np.ndarray:
    """Divide the input grid into 9 3x3 sub-grids"""
    swapped = np.swapaxes(np.reshape(grid, (3, 3, 3, 3)), 1, 2)
    return np.reshape(swapped, (9, 9))
The downside I suppose is that swapping the middle two axes of a 4D array is a bit mind-bending.
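One way to convince yourself the axis gymnastics are right is a quick equivalence check against the loop-based version from the question, using a grid of 81 distinct values so any misplacement would show up:

```python
import numpy as np

def get_subgrids_loops(grid):
    """Loop-based version from the question, for comparison."""
    subgrids = []
    for box_i in range(3):
        for box_j in range(3):
            subgrid = []
            for i in range(3):
                for j in range(3):
                    subgrid.append(grid[3*box_i + i][3*box_j + j])
            subgrids.append(subgrid)
    return np.array(subgrids)

def get_subgrids_numpy(grid):
    """Reshape to (box_i, i, box_j, j) axes, swap to (box_i, box_j, i, j)."""
    swapped = np.swapaxes(np.reshape(grid, (3, 3, 3, 3)), 1, 2)
    return np.reshape(swapped, (9, 9))

grid = np.arange(81).reshape(9, 9)
```

The `_loops`/`_numpy` suffixes are mine, just to let both versions coexist for the comparison.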
Performance
Almost all the time is spent in `get_candidates`. I think the reasons for that are mainly:

- It gets called too often. For example, after filling in a cell (such as in `fill_singles`), rather than recomputing the candidates from scratch, it would be faster to merely remove the new value from the candidates in the same row/column/house.
- If a cell is filled, its list of candidates is just the filled-in value, but the expensive set computation is done anyway. That's easy to avoid by moving those statements inside the `if`.
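The first bullet, incremental candidate maintenance, could look something like this (a sketch with a function name of my own choosing, operating on the same candidates list-of-lists structure as the post):

```python
def assign(grid, candidates, i, j, value):
    """Place `value` at (i, j) and prune it from peers' candidate lists,
    instead of rebuilding all 81 lists with get_candidates."""
    grid[i][j] = value
    candidates[i][j] = [value]
    bi, bj = 3 * (i // 3), 3 * (j // 3)
    for k in range(9):
        # Peers: same row, same column, same 3x3 box
        for r, c in ((i, k), (k, j), (bi + k // 3, bj + k % 3)):
            if (r, c) != (i, j) and grid[r][c] == 0 and value in candidates[r][c]:
                candidates[r][c].remove(value)
```

Each `assign` touches at most 20 peer cells, versus the 81 row/column/box set computations a full `get_candidates` pass performs.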
Algorithmic performance
This solver only uses Naked Singles as a propagation technique; adding Hidden Singles is, in my experience, a very large step towards an efficient solver.
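For reference, a Hidden Single is a digit that fits in only one cell of a unit, even if that cell still has several candidates. A row-only sketch on the post's candidates structure (columns and boxes work the same way; the function name is mine):

```python
def hidden_singles_in_rows(candidates):
    """Yield (i, j, digit) where `digit` can go in only one cell of row i."""
    for i in range(9):
        for digit in range(1, 10):
            spots = [j for j in range(9) if digit in candidates[i][j]]
            # Exactly one possible cell, and that cell isn't already decided
            if len(spots) == 1 and len(candidates[i][spots[0]]) > 1:
                yield i, spots[0], digit
```

Filling the cells this yields (and pruning peers) before resorting to guessing typically cuts the amount of backtracking dramatically.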
- Comment by fabrizzio_gz (Aug 23, 2020): Thank you very much for your review. I couldn't find a way to numpify the sub-grids, and your recommendation is spot on. One question: how did you find that the `get_candidates` function is the most used? Was it by simply checking the code, or did you use a tool? Thanks again.
- Comment by user555045 (Aug 23, 2020): @fabrizzio_gz I used `%%prun` in a Jupyter notebook; if you're running the code as a standalone script you could use one of these.
- Comment by fabrizzio_gz (Aug 24, 2020): Great to know. Thanks a lot.