I wrote a program that solves Pentomino puzzles using Knuth's Algorithm X. My priorities were readability, straightforwardness, and brevity of the solution, so I didn't try to squeeze out as much performance as possible by eliminating function calls or using similar tricks.
I am interested in:
- Code review. How can this code be improved?
- Performance improvement. Do you see a way to increase performance? Maybe numpy can help here? At the moment, the fastest version takes 1 hour 57 minutes on an Intel i5-2500 to find all solutions for the 5 x 12 board. The second version is three times slower, but it has less code duplication, so it is more beautiful in my opinion. What do you think?
- Reducing the code duplication of the faster version without a large performance drop like the one in my second version.
Usage:
5 x 12 board
rows = 5
cols = 12
board = [[0] * cols for _ in range(rows)]
solver = Solver(board, figures, rows, cols, 1)
solver.find_solutions()
Output (partial - the first solution only)
The solution No 1
12 12 5 5 5 5 8 8 7 7 7 7
9 12 12 12 5 6 6 8 7 11 3 3
9 9 9 2 6 6 10 8 8 11 3 3
1 9 1 2 6 10 10 10 11 11 11 3
1 1 1 2 2 2 10 4 4 4 4 4
####################################################################
8 x 8 board with 4 prefilled cells in the middle.
rows = 8
cols = 8
board = [[0] * cols for _ in range(rows)]
board[3][3] = '#'
board[3][4] = '#'
board[4][3] = '#'
board[4][4] = '#'
solver = Solver(board, figures, rows, cols, 1)
solver.find_solutions()
Output (partial)
1614 variants have tried
Time has elapsed: 0:00:05.104412
The solution No 1
12 12 4 4 4 4 4 11
6 12 12 12 5 11 11 11
6 6 5 5 5 5 10 11
8 6 6 # # 10 10 10
8 8 8 # # 9 10 7
2 1 8 1 9 9 9 7
2 1 1 1 3 3 9 7
2 2 2 3 3 3 7 7
#################################################################
The program
The slower version's methods are commented out. To test it, uncomment them and comment out the corresponding methods of the first version.
#!/usr/bin/python3
from time import time
from datetime import timedelta
figures = {
(
(1,1,1,1,1),
),
(
(1,1,1,1),
(1,0,0,0),
),
(
(1,1,1,1),
(0,1,0,0),
),
(
(1,1,1,0),
(0,0,1,1),
),
(
(1,1,1),
(1,0,1),
),
(
(1,1,1),
(0,1,1),
),
(
(1,1,1),
(1,0,0),
(1,0,0)
),
(
(1,1,0),
(0,1,1),
(0,0,1)
),
(
(1,0,0),
(1,1,1),
(0,0,1)
),
(
(0,1,0),
(1,1,1),
(1,0,0)
),
(
(0,1,0),
(1,1,1),
(0,1,0)
),
(
(1,1,1),
(0,1,0),
(0,1,0)
)
}
class Node():
    def __init__(self, value):
        self.value = value
        self.up = None
        self.down = None
        self.left = None
        self.right = None
        self.row_head = None
        self.col_head = None
class Linked_list_2D():
    def __init__(self, width):
        self.width = width
        self.head = None
        self.size = 0

    def append(self, value):
        new_node = Node(value)
        if self.head is None:
            self.head = left_neigh = right_neigh = up_neigh = down_neigh = new_node
        elif self.size % self.width == 0:
            up_neigh = self.head.up
            down_neigh = self.head
            left_neigh = right_neigh = new_node
        else:
            left_neigh = self.head.up.left
            right_neigh = left_neigh.right
            if left_neigh is left_neigh.up:
                up_neigh = down_neigh = new_node
            else:
                up_neigh = left_neigh.up.right
                down_neigh = up_neigh.down
        new_node.up = up_neigh
        new_node.down = down_neigh
        new_node.left = left_neigh
        new_node.right = right_neigh
        # Every node has links to the first node of its row and column.
        # These nodes are used as the starting point for deletion and insertion.
        new_node.row_head = right_neigh
        new_node.col_head = down_neigh
        up_neigh.down = new_node
        down_neigh.up = new_node
        right_neigh.left = new_node
        left_neigh.right = new_node
        self.size += 1

    def print_list(self, separator=' '):
        for row in self.traverse_node_line(self.head, "down"):
            for col in self.traverse_node_line(row, "right"):
                print(col.value, end=separator)
            print()

    def traverse_node_line(self, start_node, direction):
        cur_node = start_node
        while cur_node:
            yield cur_node
            cur_node = getattr(cur_node, direction)
            if cur_node is start_node:
                break

    ### First approach - a lot of code duplication
    def col_nonzero_nodes(self, node):
        cur_node = node
        while cur_node:
            if cur_node.value and cur_node.row_head is not self.head:
                yield cur_node
            cur_node = cur_node.down
            if cur_node is node:
                break

    def row_nonzero_nodes(self, node):
        cur_node = node
        while cur_node:
            if cur_node.value and cur_node.col_head is not self.head:
                yield cur_node
            cur_node = cur_node.right
            if cur_node is node:
                break

    def delete_row(self, node):
        cur_node = node
        while cur_node:
            up_neigh = cur_node.up
            down_neigh = cur_node.down
            if cur_node is self.head:
                self.head = down_neigh
                if cur_node is down_neigh:
                    self.head = None
            up_neigh.down = down_neigh
            down_neigh.up = up_neigh
            cur_node = cur_node.right
            if cur_node is node:
                break

    def insert_row(self, node):
        cur_node = node
        while cur_node:
            up_neigh = cur_node.up
            down_neigh = cur_node.down
            up_neigh.down = cur_node
            down_neigh.up = cur_node
            cur_node = cur_node.right
            if cur_node is node:
                break

    def insert_col(self, node):
        cur_node = node
        while cur_node:
            left_neigh = cur_node.left
            right_neigh = cur_node.right
            left_neigh.right = cur_node
            right_neigh.left = cur_node
            cur_node = cur_node.down
            if cur_node is node:
                break

    def delete_col(self, node):
        cur_node = node
        while cur_node:
            left_neigh = cur_node.left
            right_neigh = cur_node.right
            if cur_node is self.head:
                self.head = right_neigh
                if cur_node is right_neigh:
                    self.head = None
            left_neigh.right = right_neigh
            right_neigh.left = left_neigh
            cur_node = cur_node.down
            if cur_node is node:
                break

    ### Second approach - the common parts of the code are moved to separate functions.
    ### These functions are then used where needed, instead of duplicating the same code
    ### every time. But the performance dropped by a factor of three, so I decided not to
    ### use these methods in the program.
    ###
    # def delete_node(self, cur_node, a_node, a_neigh, b_node, b_neigh):
    #     if cur_node is self.head:
    #         self.head = b_node
    #         if cur_node is b_node:
    #             self.head = None
    #     setattr(a_node, a_neigh, b_node)
    #     setattr(b_node, b_neigh, a_node)
    #     self.size -= 1
    #
    # def insert_node(self, cur_node, a_node, a_neigh, b_node, b_neigh):
    #     setattr(a_node, a_neigh, cur_node)
    #     setattr(b_node, b_neigh, cur_node)
    #     self.size += 1
    #
    # def col_nonzero_nodes(self, node):
    #     for cur_node in self.traverse_node_line(node.col_head, "down"):
    #         if cur_node.value and cur_node.row_head is not self.head:
    #             yield cur_node
    #
    # def row_nonzero_nodes(self, node):
    #     for cur_node in self.traverse_node_line(node.row_head, "right"):
    #         if cur_node.value and cur_node.col_head is not self.head:
    #             yield cur_node
    #
    # def delete_row(self, node):
    #     for cur_node in self.traverse_node_line(node, "right"):
    #         self.delete_node(cur_node, cur_node.up, "down", cur_node.down, "up")
    #
    # def delete_col(self, node):
    #     for cur_node in self.traverse_node_line(node, "down"):
    #         self.delete_node(cur_node, cur_node.left, "right", cur_node.right, "left")
    #
    # def insert_row(self, node):
    #     for cur_node in self.traverse_node_line(node, "right"):
    #         self.insert_node(cur_node, cur_node.up, "down", cur_node.down, "up")
    #
    # def insert_col(self, node):
    #     for cur_node in self.traverse_node_line(node, "down"):
    #         self.insert_node(cur_node, cur_node.left, "right", cur_node.right, "left")
class Solver():
    def __init__(self, board, figures, rows, cols, figures_naming_start):
        self.rows = rows
        self.cols = cols
        self.fig_name_start = figures_naming_start
        self.figures = figures
        self.solutions = set()
        self.llist = None
        self.start_board = board
        self.tried_variants_num = 0

    def find_solutions(self):
        named_figures = set(enumerate(self.figures, self.fig_name_start))
        all_figure_postures = self.unique_figure_postures(named_figures)
        self.llist = Linked_list_2D(self.rows * self.cols + 1)
        pos_gen = self.generate_positions(all_figure_postures, self.rows, self.cols)
        for line in pos_gen:
            for val in line:
                self.llist.append(val)
        self.delete_filled_on_start_cells(self.llist)
        self.starttime = time()
        self.prevtime = self.starttime
        self.dlx_alg(self.llist, self.start_board)

    # Converts an index into the flattened board to two-dimensional coordinates
    def num_to_coords(self, num):
        row = num // self.cols
        col = num - row * self.cols
        return row, col

    def delete_filled_on_start_cells(self, llist):
        for col_head_node in llist.row_nonzero_nodes(llist.head):
            row, col = self.num_to_coords(col_head_node.value - 1)
            if self.start_board[row][col]:
                llist.delete_col(col_head_node)

    def print_progress(self, message, interval):
        new_time = time()
        if (new_time - self.prevtime) >= interval:
            print(message)
            print(f"Time has elapsed: {timedelta(seconds=new_time - self.starttime)}")
            self.prevtime = new_time

    def check_solution_uniqueness(self, solution):
        reflected_solution = self.reflect(solution)
        for sol in [solution, reflected_solution]:
            if sol in self.solutions:
                return
            for _ in range(3):
                sol = self.rotate(sol)
                if sol in self.solutions:
                    return
        return 1

    def dlx_alg(self, llist, board):
        # If no rows are left, all figures are used
        if llist.head.down is llist.head:
            self.print_progress(f"{self.tried_variants_num} variants have tried", 5.0)
            self.tried_variants_num += 1
            # If no columns are left, all cells are filled and the solution is found.
            if llist.head.right is llist.head:
                solution = tuple(tuple(row) for row in board)
                if self.check_solution_uniqueness(solution):
                    print(f"The solution No {len(self.solutions) + 1}")
                    self.print_board(solution)
                    self.solutions.add(solution)
            return
        # Search for the column with the minimum number of intersected rows
        min_col, min_col_sum = self.find_min_col(llist, llist.head)
        # Performance optimization: stop analyzing the branch if an empty column appears
        if min_col_sum == 0:
            self.tried_variants_num += 1
            return
        intersected_rows = []
        for node in llist.col_nonzero_nodes(min_col):
            intersected_rows.append(node.row_head)
        # Pick one row (a variant of a figure) and try to solve the puzzle with it
        for selected_row in intersected_rows:
            rows_to_restore = []
            new_board = self.add_posture_to_board(selected_row, board)
            # Once a figure is used, all other variants (postures) of this figure
            # can be discarded in this branch
            for posture_num_node in llist.col_nonzero_nodes(llist.head):
                if posture_num_node.value == selected_row.value:
                    rows_to_restore.append(posture_num_node)
                    llist.delete_row(posture_num_node)
            cols_to_restore = []
            for col_node in llist.row_nonzero_nodes(selected_row):
                for row_node in llist.col_nonzero_nodes(col_node.col_head):
                    # Delete all rows which use the same cell as the picked one,
                    # because only one figure can fill a specific cell
                    rows_to_restore.append(row_node.row_head)
                    llist.delete_row(row_node.row_head)
                # Delete the columns that the picked figure fills; they are not
                # needed in this branch anymore
                cols_to_restore.append(col_node.col_head)
                llist.delete_col(col_node.col_head)
            # Pass the shrunken llist and the board with the picked figure added
            # to the next level of processing
            self.dlx_alg(llist, new_board)
            for row in rows_to_restore:
                llist.insert_row(row)
            for col in cols_to_restore:
                llist.insert_col(col)

    def find_min_col(self, llist, min_col):
        min_col_sum = float("inf")
        for col in llist.row_nonzero_nodes(llist.head):
            tmp = sum(1 for item in llist.col_nonzero_nodes(col))
            if tmp < min_col_sum:
                min_col = col
                min_col_sum = tmp
        return min_col, min_col_sum

    def add_posture_to_board(self, posture_row, prev_steps_result):
        new_board = prev_steps_result.copy()
        for node in self.llist.row_nonzero_nodes(posture_row):
            row, col = self.num_to_coords(node.col_head.value - 1)
            new_board[row][col] = node.row_head.value
        return new_board

    def print_board(self, board):
        for row in board:
            for cell in row:
                print(f"{cell: >3}", end='')
            print()
        print()
        print("#" * 80)

    def unique_figure_postures(self, named_figures):
        postures = set(named_figures)
        for name, fig in named_figures:
            postures.add((name, self.reflect(fig)))
        all_postures = set(postures)
        for name, posture in postures:
            for _ in range(3):
                posture = self.rotate(posture)
                all_postures.add((name, posture))
        return all_postures

    # Generates entries for all possible positions of every figure's posture.
    # The items of these entries are then linked into the 2-dimensional circular linked list.
    # An entry looks like:
    # figure's name {board cells filled by figure} empty board's cells
    # | | | | | | | | |
    # 5 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ................
    def generate_positions(self, postures, rows, cols):
        def apply_posture(name, posture, y, x, wdth, hght):
            # Flatten the 2D start board
            line = [cell for row in self.start_board for cell in row]
            # Put the figure onto the flattened start board
            for r in range(hght):
                for c in range(wdth):
                    if posture[r][c]:
                        num = (r + y) * cols + x + c
                        if line[num]:
                            return
                        line[num] = posture[r][c]
            # Add the name at the beginning
            line.insert(0, name)
            return line

        # Make the column headers of the llist
        yield [i for i in range(rows * cols + 1)]
        for name, posture in postures:
            posture_height = len(posture)
            posture_width = len(posture[0])
            for row in range(rows):
                if row + posture_height > rows:
                    break
                for col in range(cols):
                    if col + posture_width > cols:
                        break
                    new_line = apply_posture(name, posture, row, col, posture_width, posture_height)
                    if new_line:
                        yield new_line

    def rotate(self, fig):
        return tuple(zip(*fig[::-1]))

    def reflect(self, fig):
        return tuple(fig[::-1])
1 Answer
Low Hanging Fruit
I obtained a 27% speed-up with one tiny change.
For this speed-up, I didn't want to wait 2 hours for tests to run, so I used a 3x20 board. Elapsed times:
- 0:01:44 - before the change
- 0:01:16 - after the change
I'm sure you'd like to know what the change was. I can't drag it out much more, so here it is. I added this line to the Node class:
    __slots__ = ('value', 'up', 'down', 'left', 'right', 'row_head', 'col_head')
This eliminates the __dict__ member of the Node objects. So instead of self.down being interpreted as self.__dict__['down'], the value referenced is in a predefined "slot" in the Node object. Not only do the references become faster, but each object also uses less space, which reduces the memory footprint and increases locality of reference, which again helps performance.
(Adding __slots__ to Linked_list_2D and Solver didn't change my performance numbers at all.)
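As a small illustration of the effect described above, here is a sketch that compares a class with and without __slots__. The names PlainNode and SlottedNode are made up for this example; they only mirror the attributes of the question's Node class.

```python
class PlainNode:
    """Hypothetical stand-in for Node without __slots__."""
    def __init__(self, value):
        self.value = value
        self.up = self.down = self.left = self.right = None
        self.row_head = self.col_head = None


class SlottedNode:
    """Same attributes, but stored in fixed slots instead of a per-object dict."""
    __slots__ = ('value', 'up', 'down', 'left', 'right', 'row_head', 'col_head')

    def __init__(self, value):
        self.value = value
        self.up = self.down = self.left = self.right = None
        self.row_head = self.col_head = None


plain = PlainNode(1)
slotted = SlottedNode(1)

print(hasattr(plain, '__dict__'))    # True: attributes live in a dict
print(hasattr(slotted, '__dict__'))  # False: no per-instance dict at all

# A side effect worth knowing: attributes not listed in __slots__ are rejected.
try:
    slotted.extra = 0
except AttributeError as exc:
    print("cannot add attribute:", exc)
```

The speed-up you actually get depends on the workload and the Python version, so it is worth timing your own board rather than relying on my 27% figure.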
Interface
class Solver():
    def __init__(self, board, figures, rows, cols, figures_naming_start):
        ...
This is an "unfriendly" interface. You have to pass all 5 values to Solver(), but some of these parameters are redundant.
You've got board, and you have rows and cols. Given a board, the Solver could determine rows = len(board) and cols = len(board[0]). Or, given rows and cols, the Solver() could construct an empty board.
figures_naming_start is likely always going to be 1. Why not use a default value of 1 for that parameter? Or pass a dictionary to figures, with the keys being the "names" of the figures. And since figures is a predefined set, why not default it to a class constant?
class Solver():
    STANDARD_FIGURES = { 'I' : ((1, 1, 1, 1, 1),),
                         'Q' : ((1, 1, 1, 1),
                                (1, 0, 0, 0)),
                         ...
                       }

    def __init__(self, board=None, rows=None, cols=None, figures=STANDARD_FIGURES):
        if board is None:
            board = [[0] * cols for _ in range(rows)]
        if rows is None and cols is None:
            rows = len(board)
            cols = len(board[0])
        ...
Usage:
solver = Solver(rows=3, cols=20)
solver.find_solutions()
Finding Solutions
When a solution is found, it is printed. What if you wanted to display it in a GUI of some kind, with coloured tiles?
It would be better for find_solutions() to yield solution, and then the caller could print the solutions, or display them in some fashion, or simply count them:
solver = Solver(rows=3, cols=20)
for solution in solver.find_solutions():
    solver.print_board(solution)
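To show how a recursive search can hand solutions back to the caller, here is a minimal sketch of Algorithm X written as a generator. It is not your linked-list implementation: it assumes a dict-of-sets representation of the exact cover matrix, and the names build_columns, solve, cover and uncover are invented for this example. The point is the `yield` at the leaf and the `yield from` on the recursive call.

```python
def build_columns(rows):
    """Invert {row: [columns]} into {column: {rows}}."""
    columns = {}
    for row, cols in rows.items():
        for col in cols:
            columns.setdefault(col, set()).add(row)
    return columns


def cover(columns, rows, row):
    """Remove every column that `row` fills, plus all rows clashing with it."""
    removed = []
    for col in rows[row]:
        for other in columns[col]:
            for other_col in rows[other]:
                if other_col != col:
                    columns[other_col].remove(other)
        removed.append(columns.pop(col))
    return removed


def uncover(columns, rows, row, removed):
    """Undo cover() in reverse order."""
    for col in reversed(rows[row]):
        columns[col] = removed.pop()
        for other in columns[col]:
            for other_col in rows[other]:
                if other_col != col:
                    columns[other_col].add(other)


def solve(columns, rows, partial=None):
    """Yield every exact cover as a list of row names."""
    if partial is None:
        partial = []
    if not columns:
        yield list(partial)  # a solution: give it to the caller, do not print
        return
    # Choose the column with the fewest candidate rows.
    col = min(columns, key=lambda c: len(columns[c]))
    for row in list(columns[col]):
        partial.append(row)
        removed = cover(columns, rows, row)
        yield from solve(columns, rows, partial)  # pass deeper solutions up
        uncover(columns, rows, row, removed)
        partial.pop()


# The small exact cover instance commonly used to illustrate Algorithm X.
rows = {'A': [1, 4, 7], 'B': [1, 4], 'C': [4, 5, 7],
        'D': [3, 5, 6], 'E': [2, 3, 6, 7], 'F': [2, 7]}
for solution in solve(build_columns(rows), rows):
    print(sorted(solution))  # ['B', 'D', 'F']
```

With this shape, printing, counting, or drawing coloured tiles all become the caller's decision, and the search itself stays free of output code.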
Progress
The progress / timing messages should be presented via the logging module, where they could be printed, or written to a file, or turned off altogether.
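A minimal sketch of what that could look like, using only the standard logging module. The logger name "pentomino" and the helper log_progress are assumptions made for this example, not part of your program.

```python
import logging
from datetime import timedelta

# Hypothetical logger name; any dotted name identifying the program works.
logger = logging.getLogger("pentomino")


def log_progress(tried_variants, elapsed_seconds):
    """Report progress through logging instead of print()."""
    logger.info("%d variants tried, time elapsed: %s",
                tried_variants, timedelta(seconds=elapsed_seconds))


if __name__ == "__main__":
    # The caller decides where messages go and how verbose they are.
    logging.basicConfig(level=logging.INFO)
    log_progress(1614, 5)
```

Switching the level to logging.WARNING silences the progress messages, and passing filename=... to basicConfig sends them to a file, all without touching the solver.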
insert_col or delete_col methods need comments?
code # comment so that the code is read and, when not understood, you just glance sideways. What I feel is lacking is the big picture, like a docstring for def dlx_alg(self, llist, board):. Your linked list, for example, is too long. Compare with this