I want to write an algorithm to count how many islands are in a given matrix. Consider for example:

    A = [
        [1, 2, 1, 3],
        [2, 2, 3, 2],
        [3, 3, 2, 3]
    ]

Directly adjacent cells with equal values (north, south, east, west, but not diagonally) form an island; this example matrix has 9.
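To make the definition concrete, the sketch below labels every cell of the example with the number of the island it belongs to. `label_islands` is a hypothetical helper written for illustration. It keeps a separate `labels` grid instead of overwriting cells with 0, so cells whose value is 0 are counted as islands too.

```python
from collections import deque


def label_islands(matrix):
    """Return a grid where each cell holds the 1-based id of its island.

    Hypothetical helper: an island is a maximal group of equal values
    connected north/south/east/west (not diagonally).
    """
    number_of_rows, number_of_columns = len(matrix), len(matrix[0])
    labels = [[0] * number_of_columns for _ in range(number_of_rows)]  # 0 = unlabelled
    next_label = 0
    for start_row in range(number_of_rows):
        for start_col in range(number_of_columns):
            if labels[start_row][start_col]:
                continue  # already part of an earlier island
            next_label += 1
            value = matrix[start_row][start_col]
            labels[start_row][start_col] = next_label
            queue = deque([(start_row, start_col)])
            while queue:
                row, col = queue.popleft()
                for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    r, c = row + d_row, col + d_col
                    if (
                        0 <= r < number_of_rows
                        and 0 <= c < number_of_columns
                        and not labels[r][c]
                        and matrix[r][c] == value
                    ):
                        labels[r][c] = next_label
                        queue.append((r, c))
    return labels


A = [
    [1, 2, 1, 3],
    [2, 2, 3, 2],
    [3, 3, 2, 3],
]
for line in label_islands(A):
    print(line)
# [1, 2, 3, 4]
# [2, 2, 5, 6]
# [7, 7, 8, 9]
```

Labels are handed out in row-major order of each island's first cell, and the highest label, 9, matches the count given above.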
I've written the following code, which works and performs fine. Any suggestions on how I could write this more cleanly and make it run even faster?

    def clean_neighbours(matrix, this_row, this_col):
        cell_value = matrix[this_row][this_col]
        if cell_value == 0:
            return
        matrix[this_row][this_col] = 0
        number_of_rows = len(matrix)
        number_of_columns = len(matrix[0])
        for shift in (
            (-1,0), (1,0), (0,-1), (0,1)
        ):
            row, col = [x+y for x,y in zip((this_row, this_col), shift)]
            if (row >= 0 and row < number_of_rows) and (
                col >= 0 and col < number_of_columns
            ):
                if matrix[row][col] == cell_value:
                    clean_neighbours(matrix, row, col)

    def count_adjacent_islands(matrix):
        number_of_islands = 0
        for row_index, row in enumerate(matrix):
            for column_index, _ in enumerate(row):
                if matrix[row_index][column_index] != 0:
                    number_of_countries += 1
                    clean_neighbours(matrix, row_index, column_index)
        return number_of_countries

    import random
    A = [ [random.randint(-1000,1000) for e in range(0,1000)] for e in range(0,1000) ]
    print(count_adjacent_islands(A))

Ben A (Dec 16, 2019 at 10:28): Please don't edit your question so that it invalidates current answers.
Simon Forsberg (Dec 16, 2019 at 11:27): Please do not update the code in your question to incorporate feedback from answers; doing so goes against the Question + Answer style of Code Review. This is not a forum where you should keep the most updated version in your question. Please see what you may and may not do after receiving answers. Also, if you want to delete your question when it has been answered, the process for that is a bit more complicated. See "If I flag my question with a request to delete it, what will happen?" in this post.
1 Answer
bug
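In `count_adjacent_islands`, the counter is initialized as `number_of_islands = 0` but incremented and returned as `number_of_countries`, so the function raises an `UnboundLocalError`. The initialization should be `number_of_countries = 0`, or the same name should be used throughout (the rewrites below use `number_of_islands`).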
mutate original argument
Most of the time, it's a bad idea to change any of the arguments to a function unless that is explicitly expected. `count_adjacent_islands` overwrites the cells of the matrix it is given with 0, so you'd better take a copy of the matrix first:

    matrix_copy = [row[:] for row in matrix]

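The nested copy matters: slicing only the outer list (`matrix[:]`) would still share every row with the caller. A short demonstration:

```python
matrix = [[1, 2], [3, 4]]

shallow = matrix[:]  # copies only the outer list; the rows are shared
shallow[0][0] = 0
print(matrix)  # [[0, 2], [3, 4]]: the caller's data changed

matrix = [[1, 2], [3, 4]]
matrix_copy = [row[:] for row in matrix]  # copies every row as well
matrix_copy[0][0] = 0
print(matrix)  # [[1, 2], [3, 4]]: the original is untouched
```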
tuple unpacking
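Instead of `for shift in ((-1,0), (1,0), (0,-1), (0,1)):`, you can unpack each offset directly with `for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):`. Then `row, col = [x+y for x,y in zip((this_row, this_col), shift)]` can be expressed a lot more clearly as `row, col = this_row + dx, this_col + dy` (written `row, col = x + dx, y + dy` in the rewrite below, where the coordinates are called `x` and `y`).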
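Going one step further, the offset loop and the bounds check can be moved into a small generator, so the flood-fill code only ever sees valid neighbours. This is a sketch; `neighbours` and `OFFSETS` are hypothetical names, not part of the original code:

```python
# The four orthogonal directions: north, south, west, east.
OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))


def neighbours(row, col, number_of_rows, number_of_columns):
    """Yield the in-bounds positions orthogonally adjacent to (row, col)."""
    for dx, dy in OFFSETS:
        r, c = row + dx, col + dy
        if 0 <= r < number_of_rows and 0 <= c < number_of_columns:
            yield r, c


print(list(neighbours(0, 0, 3, 4)))  # [(1, 0), (0, 1)]
print(list(neighbours(1, 1, 3, 4)))  # [(0, 1), (2, 1), (1, 0), (1, 2)]
```

With a helper like this, the body of the neighbour loop reduces to a single equality check on `matrix[row][col]`.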
continue
Instead of nesting more `if` conditions, you can move on to the next iteration early with `continue` when a condition is not fulfilled. For example, this loop

    for row_index, row in enumerate(matrix):
        for column_index, _ in enumerate(row):
            if matrix[row_index][column_index] != 0:
                number_of_islands += 1
                clean_neighbours(matrix, row_index, column_index)

can become:

    for row_index, row in enumerate(matrix_copy):
        for column_index, _ in enumerate(row):
            if matrix_copy[row_index][column_index] == 0:
                continue
            number_of_islands += 1
            clean_neighbours2(matrix_copy, row_index, column_index)

This saves one level of indentation on the code that does the actual work. That is not much in this particular case, but with larger nested conditions it can make things a lot clearer and save a lot of horizontal screen real estate.
recursion
If some islands are large, the recursive version will run into Python's recursion limit, since every cell of an island adds a stack frame. It is better to transform the recursion into a queue and a loop.
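To see the limit in practice, a single row of equal values is enough, because each cell is one level deeper in the recursion. The sketch below restates the recursive approach compactly (`clean_recursive` is a hypothetical name) and runs it on an island of 10 000 cells; CPython's default limit, reported by `sys.getrecursionlimit()`, is 1000.

```python
import sys


def clean_recursive(matrix, row, col, value):
    """Recursive flood fill in the style of the original clean_neighbours.

    Every cell of the island adds one Python stack frame.
    """
    if not (0 <= row < len(matrix) and 0 <= col < len(matrix[0])):
        return  # off the grid
    if matrix[row][col] != value:
        return  # different value, or already cleaned
    matrix[row][col] = 0  # mark as visited, as the original does
    for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        clean_recursive(matrix, row + dx, col + dy, value)


print(sys.getrecursionlimit())  # 1000 by default in CPython

long_island = [[1] * 10_000]  # a single island: one row of 10 000 equal cells
try:
    clean_recursive(long_island, 0, 0, 1)
    outcome = "finished"
except RecursionError:
    outcome = "RecursionError"
print(outcome)  # RecursionError with the default limit
```

The queue-based rewrite avoids this, because the pending cells are kept in a `deque` instead of in Python stack frames.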

    from collections import deque

    def clean_neighbours2(matrix, x, y):
        cell_value = matrix[x][y]
        if cell_value == 0:
            return
        matrix[x][y] = 0
        queue = deque([(x, y)])
        while queue:
            x, y = queue.pop()
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                row, col = x + dx, y + dy
                # skip neighbours that are off the grid or belong to another island
                if (
                    not 0 <= row < len(matrix)
                    or not 0 <= col < len(matrix[0])
                    or matrix[row][col] != cell_value
                ):
                    continue
                queue.append((row, col))
                matrix[row][col] = 0  # mark as visited when enqueued

    def count_adjacent_islands2(matrix):
        matrix_copy = [row[:] for row in matrix]
        number_of_islands = 0
        for row_index, row in enumerate(matrix_copy):
            for column_index, _ in enumerate(row):
                if matrix_copy[row_index][column_index] == 0:
                    continue
                number_of_islands += 1
                clean_neighbours2(matrix_copy, row_index, column_index)
        return number_of_islands

On my machine, for the sample data you provided, this code took 3 s compared to 4 s for the original.
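A side note on `clean_neighbours2`: `deque.pop()` takes items from the right end, so the loop actually explores depth-first, like a stack; `popleft()` would make it breadth-first. For counting islands the order doesn't matter, because both visit exactly the same cells. A minimal check, using a hypothetical `flood_fill` helper that returns the visited cells instead of zeroing them:

```python
from collections import deque


def flood_fill(matrix, start_row, start_col, breadth_first):
    """Return the set of cells in the island containing (start_row, start_col).

    breadth_first=True uses popleft() (FIFO queue); False uses pop() (LIFO stack).
    """
    value = matrix[start_row][start_col]
    visited = {(start_row, start_col)}
    pending = deque([(start_row, start_col)])
    while pending:
        row, col = pending.popleft() if breadth_first else pending.pop()
        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            r, c = row + dx, col + dy
            if (
                0 <= r < len(matrix)
                and 0 <= c < len(matrix[0])
                and (r, c) not in visited
                and matrix[r][c] == value
            ):
                visited.add((r, c))  # mark on push, so each cell is queued once
                pending.append((r, c))
    return visited


A = [
    [1, 2, 1, 3],
    [2, 2, 3, 2],
    [3, 3, 2, 3],
]
print(flood_fill(A, 0, 1, breadth_first=True) == flood_fill(A, 0, 1, breadth_first=False))  # True
print(sorted(flood_fill(A, 0, 1, breadth_first=True)))  # [(0, 1), (1, 0), (1, 1)]
```

Because cells are marked as visited when they are pushed rather than when they are popped (as `clean_neighbours2` does by zeroing on `append`), each cell enters the container at most once in either mode.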
alternative approach
Using numba and numpy, with a slight rewrite to stay within what numba supports:

    from numba import jit
    import numpy as np

    @jit()
    def clean_neighbours_jit(matrix, x, y):
        cell_value = matrix[x, y]
        if cell_value == 0:
            return
        matrix[x, y] = 0
        queue = [(x, y)]
        row_length, column_length = matrix.shape
        while queue:
            x, y = queue.pop()
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                row, col = x + dx, y + dy
                if (
                    not 0 <= row < row_length
                    or not 0 <= col < column_length
                    or matrix[row, col] != cell_value
                ):
                    continue
                queue.append((row, col))
                matrix[row, col] = 0

    @jit()
    def count_adjacent_islands_jit(matrix):
        matrix_copy = matrix.copy()
        number_of_islands = 0
        row_length, column_length = matrix_copy.shape
        for row_index in range(row_length):
            for column_index in range(column_length):
                if matrix_copy[row_index, column_index] == 0:
                    continue
                number_of_islands += 1
                clean_neighbours_jit(matrix_copy, row_index, column_index)
        return number_of_islands

This expects a numpy array as `matrix` (for example `count_adjacent_islands_jit(np.array(A))`), but does the job in about 200 to 300 ms, of which about 80 ms is spent converting `A` to an `np.array`. That is more than a 10x speedup.
Blasco (Dec 14, 2019 at 10:26): "about 80ms spent on converting A to an np.array": how did you carry out this analysis?
Blasco (Dec 14, 2019 at 10:27): Thank you very much for your review. Great ideas; I haven't worked with numba before, quite impressive!
Maarten Fabré (Dec 14, 2019 at 10:41): The 80 ms comes from comparing `A2 = np.array(A); %timeit count_adjacent_islands_jit(A2)` with `%timeit count_adjacent_islands_jit(np.array(A))` (in a Jupyter notebook). And numba is impressive indeed. Switching to numpy alone seemed to slow the code down, and jitting the original method didn't give any advantage.
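Outside a notebook, the same measurement can be sketched with the standard `timeit` module: time the function on pre-converted input, time it again with the conversion inside the call, and take the difference. `conversion_overhead` is a hypothetical helper; in the setting above, `convert` would be `np.array` and `func` would be `count_adjacent_islands_jit`. The example call below uses a stdlib-only stand-in so it runs without numpy or numba.

```python
import timeit


def conversion_overhead(func, raw, convert, number=10):
    """Estimate the average seconds per call spent in convert(raw).

    Mirrors comparing `%timeit func(prepared)` with `%timeit func(convert(raw))`.
    """
    prepared = convert(raw)
    with_conversion = timeit.timeit(lambda: func(convert(raw)), number=number)
    without_conversion = timeit.timeit(lambda: func(prepared), number=number)
    return (with_conversion - without_conversion) / number


# Stand-in example: "converting" is copying every row, "work" is summing the grid.
raw = [[(i * j) % 7 for j in range(300)] for i in range(300)]
overhead = conversion_overhead(
    func=lambda m: sum(map(sum, m)),
    raw=raw,
    convert=lambda m: [row[:] for row in m],
)
print(f"about {overhead * 1000:.2f} ms per call spent converting")
```

Both timings are noisy, so the difference is only an estimate and can even come out slightly negative when the conversion is cheap.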