I'm thinking I implemented it optimally, yet somehow it's much slower than the approach that should itself be slower, np.argmax(np.abs(x)). Where am I off?
Code rationale & results
- Mathematically, abs is sqrt(real**2 + imag**2), but argmax(abs(x)) == argmax(abs(x)**2), so there's no need for the square root (a sanity-check sketch of this identity follows the timings below).
- np.abs(x) also allocates and writes a full array. Instead I overwrite a single value, current_abs2, which should eliminate the allocation and leave only the writing.
- The argmax logic should be identical to NumPy's (I've not checked, but there's only one best way to do it?).
- The views (R, I) are for... I don't recall, saw it somewhere.
So the savings are in dropping the sqrt and the len(x)-sized allocation. Yet it's much slower...
%timeit np.argmax(np.abs(x))
%timeit abs_argmax(x.real, x.imag)
409 μs ± 2.33 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.09 ms ± 14.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
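As a sanity check of the sqrt-free identity from the rationale above, here is a pure-NumPy sketch (illustration only; unlike the Cython loop it still allocates a temporary array):
import numpy as np

x = np.random.randn(100000) + 1j*np.random.randn(100000)

# argmax is invariant under the monotonic sqrt, so dropping it changes nothing
i1 = np.argmax(np.abs(x))                 # sqrt(real**2 + imag**2) per element
i2 = np.argmax(x.real**2 + x.imag**2)     # no sqrt, same winning index
print(i1, i2)                             # expected to match for random data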
Here's the generated C code, just the function; the whole _optimized.c is 26000 lines.
A Numba version achieves 108 μs, which is very satisfactory, though I'm interested in why Cython fails.
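The exact Numba code behind that number isn't shown here; the following is only a minimal sketch of what such an njit loop might look like (function name and decorator options are assumptions, not the measured code):
import numba
import numpy as np

@numba.njit(cache=True)
def abs_argmax_numba(re, im):
    # scan once, tracking the largest |x|**2 seen so far and its index
    max_idx = 0
    current_max = 0.0
    for i in range(re.shape[0]):
        abs2 = re[i]*re[i] + im[i]*im[i]
        if abs2 > current_max:
            max_idx = i
            current_max = abs2
    return max_idx

x = np.random.randn(100000) + 1j*np.random.randn(100000)
abs_argmax_numba(x.real, x.imag)   # first call compiles; time subsequent calls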
Code
import cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int abs_argmax(double[:] re, double[:] im):
# initialize variables
cdef Py_ssize_t N = re.shape[0]
cdef double[:] R = re # view
cdef double[:] I = im # view
cdef Py_ssize_t i = 0
cdef int max_idx = 0
cdef double current_max = 0
cdef double current_abs2 = 0
# main loop
while i < N:
current_abs2 = R[i]**2 + I[i]**2
if current_abs2 > current_max:
max_idx = i
current_max = current_abs2
i += 1
# return
return max_idx
Setup & execution
I use python setup.py build_ext --inplace (setup.py shown at the bottom). Then:
import numpy as np
from _optimized import abs_argmax
x = np.random.randn(100000) + 1j*np.random.randn(100000)
%timeit np.argmax(np.abs(x))
%timeit abs_argmax(x.real, x.imag)
setup.py
(I forget the rationale; I just followed certain recommendations.)
from distutils import _msvccompiler
_msvccompiler.PLAT_TO_VCVARS['win-amd64'] = 'amd64'
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np
setup(
ext_modules=cythonize(Extension("_optimized", ["_optimized.pyx"]),
language_level=3),
include_dirs=[np.get_include()],
)
Environment
Windows 11, i7-13700HX CPU, Python 3.11.4, Cython 3.0.0, setuptools 68.0.0, numpy 1.24.4
3 Answers
In the generated code (slightly edited for readability):
__pyx_v_current_abs2 = (
pow((*((double *) ((__pyx_v_R.data + __pyx_t_2 * __pyx_v_R.strides[0]) ))), 2.0) +
pow((*((double *) ((__pyx_v_I.data + __pyx_t_3 * __pyx_v_I.strides[0]) ))), 2.0)
);
I do not like the calls to pow. Apparently, Cython is not smart enough and transpiles ** 2 into a function call rather than a simple multiplication. Try to help it:
current_abs2 = R[i]*R[i] + I[i]*I[i]
and see what happens. As for the rest (failed branch predictions, missed vectorization, etc.), we can only theorize.
The OP code suffers from at least a 6x factor of suckage. The question would benefit from posting Godbolt links that examine the differences in the generated code.
Often the trouble boils down to:
- inability to vectorize
- poor branch prediction
if current_abs2 > current_max:
It seems plausible that branch prediction is falling apart there.
Consider pre-computing a bunch of abs2 values, sorting them descending, and returning the zeroth element. It's O(N log N), so theoretically worse by a factor of log N, but it might interact with the memory hierarchy better in practice. Bench, and report back.
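A minimal benchmark sketch of that suggestion (illustrative names; the sort route takes the last index of an ascending argsort, which is equivalent to sorting descending and taking the zeroth element):
import numpy as np
from timeit import timeit

x = np.random.randn(100000) + 1j*np.random.randn(100000)

def via_argmax(x):
    # one vectorized pass to build |x|**2, then a single linear argmax scan
    return np.argmax(x.real**2 + x.imag**2)

def via_sort(x):
    # full O(N log N) sort of the squared magnitudes, keep only the top index
    return np.argsort(x.real**2 + x.imag**2)[-1]

print(timeit(lambda: via_argmax(x), number=1000))
print(timeit(lambda: via_sort(x), number=1000))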
Comment (OverLordGoldDragon, Jul 22, 2023): I don't know if branch prediction-based improvement is possible, the data is randomized. Sorting is surely much slower than argmax, and not just per log N. Vectorization, perhaps, though it should be up to an x4 difference on a single core from what I've read. I don't know how to make Godbolt work, but I added a C excerpt.
Thanks to vnp's answer; here's np.argmax(np.abs(x)) vs Numba vs Cython:
414 μs ± 8.09 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
89.9 μs ± 1.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
66.8 μs ± 1.11 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Code
import cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int abs_argmax(double[:] R, double[:] I):
# initialize variables
cdef Py_ssize_t N = R.shape[0]
cdef Py_ssize_t i = 0
cdef int max_idx = 0
cdef double current_max = 0
cdef double current_abs2 = 0
# main loop
for i in range(N):
current_abs2 = R[i]*R[i] + I[i]*I[i]
if current_abs2 > current_max:
max_idx = i
current_max = current_abs2
return max_idx
Comments:
- Regarding while i < N; i += 1: use a proper for loop, which can benefit from loop unrolling and other compiler optimizations. I presume your build is an optimized build, but I don't know how Cython builds by default. Double-check that.
- Are you sure cdef double[:] R = re is not a copy? Why do you need this anyway? Doesn't x.real, x.imag create copies too?
- (OP) x.real and x.imag don't create copies (poor MATLAB?). Also just found this. I don't recall what's up with the while; again, I have some precedent from months ago. I tested both your suggestions, they didn't help, but the code is certainly cleaner.
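On the optimized-build question above, here is a sketch of how the compiler flags could be made explicit in setup.py, assuming MSVC on Windows (/O2); GCC/Clang would take e.g. -O3 instead. This is an illustration, not the build the timings came from:
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

ext = Extension(
    "_optimized",
    ["_optimized.pyx"],
    include_dirs=[np.get_include()],
    extra_compile_args=["/O2"],   # hypothetical: request MSVC optimization explicitly
)

setup(ext_modules=cythonize(ext, language_level=3))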