I have a NumPy array of shape (4096, 4096) and an array of values that should be set to zero in it. I found the function numpy.in1d, which works fine but is too slow for my computations. Is there a faster way to do this? I need to repeat the operation on a very large number of matrices, so every optimization helps.
Here is an example:
The numpy array looks like this:
npArr = np.array([
[1, 4, 5, 5, 3],
[2, 5, 6, 6, 1],
[0, 0, 1, 0, 0],
[3, 3, 2, 4, 3]])
and the other array is:
arr = np.array([3,5,8])
After the replacement, npArr should look like this:
array([[ 1, 4, 0, 0, 0],
[ 2, 0, 6, 6, 1],
[ 0, 0, 1, 0, 0],
[ 0, 0, 2, 4, 0]])
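For reference, a minimal sketch of the straightforward vectorised approach the question is trying to beat, using np.isin (the modern spelling of np.in1d, which reshapes for you):

```python
import numpy as np

npArr = np.array([[1, 4, 5, 5, 3],
                  [2, 5, 6, 6, 1],
                  [0, 0, 1, 0, 0],
                  [3, 3, 2, 4, 3]])
arr = np.array([3, 5, 8])

# Build a boolean mask of positions whose value occurs in `arr`,
# then zero those positions in place.
mask = np.isin(npArr, arr)   # on older NumPy: np.in1d(npArr, arr).reshape(npArr.shape)
npArr[mask] = 0
```

This produces the expected result above; the answers below aim to do the same thing faster.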
- Your problem description is too vague. Give us an example with a small quadratic matrix and the expected result. – timgeb, May 12, 2017 at 17:01
- I am sorry for that. Is it clearer now? – Đorđe Ivanović, May 12, 2017 at 17:30
- Yes, much better! – timgeb, May 13, 2017 at 17:24
- Did either of the posted solutions work for you? – Divakar, May 16, 2017 at 11:34
- Sorry for the late comment; your answer is great. Thank you once again. – Đorđe Ivanović, May 19, 2017 at 18:22
3 Answers
If you have numba, you can solve this with a custom function that doesn't need an intermediate mask:
import numpy as np
import numba as nb

@nb.njit
def replace_where(arr, needle, replace):
    # ravel() on a contiguous array returns a view, so the writes
    # below modify the original array in place.
    arr = arr.ravel()
    needles = set(needle)  # set gives O(1) membership tests
    for idx in range(arr.size):
        if arr[idx] in needles:
            arr[idx] = replace
This gives the correct result for your example:
npArr = np.array([[1, 4, 5, 5, 3],
[2, 5, 6, 6, 1],
[0, 0, 1, 0, 0],
[3, 3, 2, 4, 3]])
arr = np.array([3,5,8])
replace_where(npArr, arr, 0)
print(npArr)
# [[1 4 0 0 0]
#  [2 0 6 6 1]
#  [0 0 1 0 0]
#  [0 0 2 4 0]]
And it should be really, really fast. I timed it for several array sizes, and it was 5-20 times faster than np.in1d (depending on the sizes, especially the size of arr).
Here's an alternative using np.searchsorted -
def in1d_alternative_2D(npArr, arr):
idx = np.searchsorted(arr, npArr.ravel())
idx[idx==len(arr)] = 0
return arr[idx].reshape(npArr.shape) == npArr
It assumes arr is sorted. If it's not, we need to sort it first and then use the posted method.
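To handle the unsorted case, a small wrapper (the name in1d_alternative_2D_any is hypothetical) can sort arr before applying the same searchsorted trick:

```python
import numpy as np

def in1d_alternative_2D_any(npArr, arr):
    # Sort the search values first so np.searchsorted applies.
    arr = np.sort(arr)
    # searchsorted gives, for each element, where it would be inserted in
    # `arr`; if the element is actually present, arr[idx] equals it.
    idx = np.searchsorted(arr, npArr.ravel())
    idx[idx == len(arr)] = 0   # clamp out-of-range insertion points
    return (arr[idx] == npArr.ravel()).reshape(npArr.shape)
```

The comparison at the end is what turns the insertion indices into a membership mask: only elements that are genuinely present in arr compare equal to arr[idx].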
Sample run -
In [90]: npArr = np.array([[1, 4, 5, 5, 3],
...: [2, 5, 6, 6, 1],
...: [0, 0, 1, 0, 0],
...: [3, 3, 2, 14, 3]])
...:
...: arr = np.array([3,5,8])
...:
In [91]: in1d_alternative_2D(npArr, arr)
Out[91]:
array([[False, False, True, True, True],
[False, True, False, False, False],
[False, False, False, False, False],
[ True, True, False, False, True]], dtype=bool)
In [92]: npArr[in1d_alternative_2D(npArr, arr)] = 0
In [93]: npArr
Out[93]:
array([[ 1, 4, 0, 0, 0],
[ 2, 0, 6, 6, 1],
[ 0, 0, 1, 0, 0],
[ 0, 0, 2, 14, 0]])
Benchmarking against numpy.in1d
The equivalent solution using np.in1d would be:
np.in1d(npArr, arr).reshape(npArr.shape)
Let's time our proposed one against it and also verify results for the sizes mentioned in the question.
In [85]: # (4096, 4096) shaped 'npArr' and search array 'arr' of 1000 elems
...: npArr = np.random.randint(0,10000,(4096,4096))
...: arr = np.sort(np.random.choice(10000, 1000, replace=0 ))
...:
In [86]: out1 = np.in1d(npArr, arr).reshape(npArr.shape)
...: out2 = in1d_alternative_2D(npArr, arr)
...:
In [87]: np.allclose(out1, out2)
Out[87]: True
In [88]: %timeit np.in1d(npArr, arr).reshape(npArr.shape)
1 loops, best of 3: 3.04 s per loop
In [89]: %timeit in1d_alternative_2D(npArr, arr)
1 loops, best of 3: 1 s per loop
Another solution using numpy broadcasting (note that this relies on the values being non-negative, since np.min is what selects the replacement 0):
np.min(np.where(npArr[None,:,:] == arr[:,None,None], 0, npArr), 0)
Out[730]:
array([[1, 4, 0, 0, 0],
[2, 0, 6, 6, 1],
[0, 0, 1, 0, 0],
[0, 0, 2, 4, 0]])
With npArr of size (4096, 4096) and arr of size 3, this creates a 3 × 4096 × 4096 array (about 50 million elements). And because np.where evaluates the condition plus both branches, several of those huge arrays exist at once (I would guess 3, but I'm not sure). That can quickly exhaust the RAM.
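When memory is the bottleneck, one low-memory alternative (a sketch, with a hypothetical name replace_where_lowmem) is to loop over the search values instead of broadcasting them, so the peak overhead is a single boolean mask of npArr's shape rather than len(arr) stacked copies:

```python
import numpy as np

def replace_where_lowmem(npArr, arr, replace=0):
    # One pass per search value; each pass allocates only one
    # boolean mask the size of npArr.
    for v in arr:
        npArr[npArr == v] = replace
```

This is a reasonable trade-off when arr is small (like the size-3 array here); for large arr, the per-value passes add up and the in1d/searchsorted approaches win.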