I'm working on optimizing a Python module for speed. I have pinned down the bottleneck to a code snippet which calculates three np.ndarrays, namely the following:
xh = np.multiply(K_Rinv[0, 0], x )
xh += np.multiply(K_Rinv[0, 1], y)
xh += np.multiply(K_Rinv[0, 2], h)
yh = np.multiply(K_Rinv[1, 0], x)
yh += np.multiply(K_Rinv[1, 1], y)
yh += np.multiply(K_Rinv[1, 2], h)
q = np.multiply(K_Rinv[2, 0], x)
q += np.multiply(K_Rinv[2, 1], y)
q += np.multiply(K_Rinv[2, 2], h)
where x, y and h are np.ndarrays with shape (4206, 5749) and K_Rinv is an np.ndarray with shape (3, 3). This code snippet is called multiple times and takes more than 50% of the runtime of the whole program. Is there a way to speed this up? Or is it just as it is and can't be sped up?
Every idea and criticism is welcome.
Comment: Can you please title your post for what the code presented is to accomplish? Possibly you'll get even more helpful answers when you provide more context. – greybeard, Jan 10, 2020 at 10:26
1 Answer
First a quick non-performance-related note: np.multiply could simply be replaced by * in your case, since it's basically scalar × array. That would make the code less verbose:
xh = K_Rinv[0, 0] * x
xh += K_Rinv[0, 1] * y
xh += K_Rinv[0, 2] * h
My first intuition on your problem was that it lends itself to being rewritten as a dot product. See the following example:
import numpy as np
K_Rinv = np.random.rand(3, 3)
x = np.random.rand(4206, 5749)
y = np.random.rand(4206, 5749)
h = np.random.rand(4206, 5749)
xyh = np.stack((x, y, h), axis=-1)  # shape (4206, 5749, 3): last axis lines up with K_Rinv[0, :]
xh_dot = np.dot(xyh, K_Rinv[0, :])  # contracts the last axis, result shape (4206, 5749)
But it turns out that this is about 4x slower than what you have. The "vectorized" version below is still about 1.5x slower than what you have on my toy dataset:
xh_vec = (xyh * K_Rinv[0, :]).sum(axis=2)
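For reference, here is roughly how such a comparison could be done with timeit, reusing the arrays defined above (the exact numbers will of course vary from machine to machine):

import timeit

def original():
    # the approach from the question, shown for the first output only
    xh = K_Rinv[0, 0] * x
    xh += K_Rinv[0, 1] * y
    xh += K_Rinv[0, 2] * h
    return xh

# each variant computes the same (4206, 5749) array
print(timeit.timeit(original, number=10))
print(timeit.timeit(lambda: np.dot(xyh, K_Rinv[0, :]), number=10))
print(timeit.timeit(lambda: (xyh * K_Rinv[0, :]).sum(axis=2), number=10))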
A quick check with np.allclose(...) does, however, confirm that the results are indeed (numerically) equivalent.
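For completeness, such a check could look like this:

xh = K_Rinv[0, 0] * x + K_Rinv[0, 1] * y + K_Rinv[0, 2] * h

# allclose rather than array_equal, since the summation order differs slightly
print(np.allclose(xh, xh_dot))  # True
print(np.allclose(xh, xh_vec))  # True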
You might see some improvement by putting your code into a separate function and applying the @jit(nopython=True) decorator from numba, if that's an option. numba is a just-in-time compiler for Python code, and is also aware of some, but not all, numpy functions. For the simple test above, the time was a bit more than half of that without numba. When measuring, take care to exclude the first run, since that's when the compilation takes place.
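A minimal sketch of what that could look like (the function name transform is just a placeholder, and this assumes numba is installed):

from numba import jit

@jit(nopython=True)
def transform(K_Rinv, x, y, h):
    # numba compiles these broadcasted scalar-times-array expressions
    # to machine code on the first call
    xh = K_Rinv[0, 0] * x + K_Rinv[0, 1] * y + K_Rinv[0, 2] * h
    yh = K_Rinv[1, 0] * x + K_Rinv[1, 1] * y + K_Rinv[1, 2] * h
    q = K_Rinv[2, 0] * x + K_Rinv[2, 1] * y + K_Rinv[2, 2] * h
    return xh, yh, q

transform(K_Rinv, x, y, h)  # first call triggers compilation; time the later calls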
Unfortunately, I'm a little bit short on time at the moment to look into this further.