In the context of a Gibbs sampler, I profiled my code and found that the major bottleneck is the following: I need to compute the likelihood of N points, assuming they have been drawn from N normal distributions (with different means but the same variance).
Here are two ways to compute it:
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
# Toy data
y = np.random.uniform(low=-1, high=1, size=100) # data points
loc = np.zeros(len(y)) # means
# Two alternatives
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit sum(norm.logpdf(y, loc=loc, scale=1))
The first: use the recently implemented `multivariate_normal` of scipy. Build the equivalent N-dimensional Gaussian and compute the (log-)probability of the N-dimensional vector `y`.

1000 loops, best of 3: 1.33 ms per loop
The second: use the traditional `norm` function of scipy. Compute the individual (log-)probability of every point in `y` and then sum the results.

10000 loops, best of 3: 130 µs per loop
Since this is part of a Gibbs sampler, I need to repeat this computation around 10,000 times, and therefore I need it to be as fast as possible.
How can I improve it (either from Python, or by calling Cython, R, or whatever)?
1 Answer
You should use a line profiler tool to examine what the slowest parts of the code are. It sounds like you did that for your own code, but you could keep going and profile the source code that NumPy and SciPy use when calculating your quantity of interest. The [Line profiler](https://pypi.python.org/pypi/line_profiler/)
module is my favorite.
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
%lprun -f norm.logpdf norm.logpdf(x=np.random.random(1000000), \
loc=np.random.random(1000000), \
scale = np.random.random())
Timer unit: 1e-06 s
Total time: 0.14831 s
File: /opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py
Function: logpdf at line 1578
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1578 def logpdf(self, x, *args, **kwds):
1579 """
1580 Log of the probability density function at x of the given RV.
1581
1582 This uses a more numerically accurate calculation if available.
1583
1584 Parameters
1585 ----------
1586 x : array_like
1587 quantiles
1588 arg1, arg2, arg3,... : array_like
1589 The shape parameter(s) for the distribution (see docstring of the
1590 instance object for more information)
1591 loc : array_like, optional
1592 location parameter (default=0)
1593 scale : array_like, optional
1594 scale parameter (default=1)
1595
1596 Returns
1597 -------
1598 logpdf : array_like
1599 Log of the probability density function evaluated at x
1600
1601 """
1602 1 14 14.0 0.0 args, loc, scale = self._parse_args(*args, **kwds)
1603 1 23 23.0 0.0 x, loc, scale = map(asarray, (x, loc, scale))
1604 1 2 2.0 0.0 args = tuple(map(asarray, args))
1605 1 13706 13706.0 9.2 x = asarray((x-loc)*1.0/scale)
1606 1 33 33.0 0.0 cond0 = self._argcheck(*args) & (scale > 0)
1607 1 5331 5331.0 3.6 cond1 = (scale > 0) & (x >= self.a) & (x <= self.b)
1608 1 5625 5625.0 3.8 cond = cond0 & cond1
1609 1 84 84.0 0.1 output = empty(shape(cond), 'd')
1610 1 6029 6029.0 4.1 output.fill(NINF)
1611 1 11459 11459.0 7.7 putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
1612 1 1093 1093.0 0.7 if any(cond):
1613 1 58499 58499.0 39.4 goodargs = argsreduce(cond, *((x,)+args+(scale,)))
1614 1 6 6.0 0.0 scale, goodargs = goodargs[-1], goodargs[:-1]
1615 1 46401 46401.0 31.3 place(output, cond, self._logpdf(*goodargs) - log(scale))
1616 1 4 4.0 0.0 if output.ndim == 0:
1617 return output[()]
1618 1 1 1.0 0.0 return output
It looks like a not-insignificant amount of time is being spent checking and removing invalid arguments from the function input. If you can be sure you will never need to use that feature, just write your own function to calculate the `logpdf`.
Plus, if you are going to be multiplying probabilities (i.e. adding log probabilities), you can use algebra to simplify and factor common terms out of the summand of the normal distribution's pdf. That lowers the number of calls to `np.log`, etc. I did this in a hurry, so I probably made a math mistake, but here is the idea:
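In symbols, with a common scale $\sigma$ and per-point means $\mu_i$, the summed log-density collapses to a single normalization term plus a sum of squares:

$$\sum_{i=1}^{N}\log\mathcal{N}(y_i \mid \mu_i,\sigma^2) \;=\; -N\log\!\left(\sigma\sqrt{2\pi}\right) \;-\; \frac{1}{2\sigma^{2}}\sum_{i=1}^{N}\left(y_i-\mu_i\right)^{2},$$

which is what `my_logpdf_sum` below computes.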
def my_logpdf_sum(x, loc, scale):
    root2 = np.sqrt(2)
    root2pi = np.sqrt(2*np.pi)
    # normalization constant, pulled out of the sum (N points share the same scale)
    prefactor = - x.size * np.log(scale * root2pi)
    # per-point quadratic term: -((x - loc)^2) / (2 * scale^2)
    summand = -np.square((x - loc)/(root2 * scale))
    return prefactor + summand.sum()
# toy data
y = np.random.uniform(low=-1, high=1, size=1000) # data points
loc = np.zeros(y.shape)
# timing
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit np.sum(norm.logpdf(y, loc=loc, scale=1))
%timeit my_logpdf_sum(y, loc, 1)
1 loops, best of 3: 156 ms per loop
10000 loops, best of 3: 125 μs per loop
The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 3: 16.3 μs per loop
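If you want to squeeze a bit more out of it inside the Gibbs loop, you can also hoist everything that does not change between iterations (the normalization constant and $1/2\sigma^2$, since the variance is fixed) out of the per-iteration call. A minimal sketch along those lines, assuming the scale really is constant across iterations (the helper name `make_logpdf_sum` is just for illustration, not part of SciPy):

import numpy as np

def make_logpdf_sum(n, scale):
    """Build a closure that evaluates the summed normal log-pdf for fixed n and scale."""
    # constants that depend only on n and scale, computed once before the sampler runs
    log_norm = -n * np.log(scale * np.sqrt(2.0 * np.pi))
    inv_two_var = 1.0 / (2.0 * scale ** 2)

    def logpdf_sum(x, loc):
        resid = x - loc
        # np.dot(resid, resid) is the sum of squared residuals
        return log_norm - inv_two_var * np.dot(resid, resid)

    return logpdf_sum

Usage would be something like `logpdf_sum = make_logpdf_sum(y.size, 1.0)` once before the sampler starts, then `logpdf_sum(y, loc)` in each iteration; the per-call work reduces to a subtraction and a dot product.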
Use `norm.logpdf(y, loc=loc, scale=1).sum()` instead of `sum(norm.logpdf(y, loc=loc, scale=1))`, since `sum` is a generic Python function, whereas `.sum()` is an optimized NumPy method.
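For example, a quick check that the two spellings give the same number, with `.sum()` staying inside NumPy instead of looping at the Python level:

import numpy as np
from scipy.stats import norm

y = np.random.uniform(low=-1, high=1, size=1000)
lp = norm.logpdf(y, loc=np.zeros(y.shape), scale=1)

# identical result; lp.sum() avoids the Python-level iteration done by sum()
assert np.isclose(sum(lp), lp.sum())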