
In the context of a Gibbs sampler, I profiled my code and my major bottleneck is the following:

I need to compute the likelihood of N points assuming they have been drawn from N normal distributions (with different means but same variance).
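
To be explicit, the quantity I need is the sum of the individual normal log-densities (writing the means as mu_i and the shared standard deviation as sigma; the notation is mine, not from the code):

$$\log L = \sum_{i=1}^{N} \log \mathcal{N}\left(y_i \mid \mu_i, \sigma^2\right) = -\frac{N}{2}\log\left(2\pi\sigma^2\right) - \frac{1}{2\sigma^2}\sum_{i=1}^{N}\left(y_i - \mu_i\right)^2$$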

Here are two ways to compute it:

import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
# Toy data
y = np.random.uniform(low=-1, high=1, size=100) # data points
loc = np.zeros(len(y)) # means
# Two alternatives
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit sum(norm.logpdf(y, loc=loc, scale=1))
  • The first: use the recently implemented multivariate_normal from SciPy. Build the equivalent N-dimensional Gaussian and compute the (log-)probability of an N-dimensional y.

    1000 loops, best of 3: 1.33 ms per loop

  • The second: use the traditional norm function from SciPy. Compute the individual (log-)probability of each point in y and then sum the results.

    10000 loops, best of 3: 130 μs per loop

Since this is part of a Gibbs sampler, I need to repeat this computation around 10,000 times, so it needs to be as fast as possible.

How can I improve it?

(either in pure Python or by calling out to Cython, R, or anything else)

asked Nov 13, 2014 at 11:22
  • It's a little faster (about 20-30%) if you use norm.logpdf(y, loc=loc, scale=1).sum() instead of sum(norm.logpdf(y, loc=loc, scale=1)), as sum is a generic Python function, whereas .sum() is an optimized NumPy method. Commented Nov 15, 2014 at 3:07
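
For concreteness, here is the commenter's suggestion applied to the toy data above (a minimal sketch; only the final reduction changes):

%timeit sum(norm.logpdf(y, loc=loc, scale=1))   # Python's built-in sum iterates over the result array
%timeit norm.logpdf(y, loc=loc, scale=1).sum()  # NumPy's vectorized reduction on the same array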

1 Answer


You should use a line profiler tool to examine what the slowest parts of the code are. It sounds like you did that for your own code, but you could keep going and profile the source code that NumPy and SciPy use when calculating your quantity of interest. The [Line profiler](https://pypi.python.org/pypi/line_profiler/) module is my favorite.

import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm

%load_ext line_profiler  # make the %lprun magic available in IPython
%lprun -f norm.logpdf norm.logpdf(x=np.random.random(1000000), \
                                  loc=np.random.random(1000000), \
                                  scale=np.random.random())

Timer unit: 1e-06 s
Total time: 0.14831 s
File: /opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py
Function: logpdf at line 1578
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1578                                           def logpdf(self, x, *args, **kwds):
  1579                                               """
  1580                                               Log of the probability density function at x of the given RV.
  1581
  1582                                               This uses a more numerically accurate calculation if available.
  1583
  1584                                               Parameters
  1585                                               ----------
  1586                                               x : array_like
  1587                                                   quantiles
  1588                                               arg1, arg2, arg3,... : array_like
  1589                                                   The shape parameter(s) for the distribution (see docstring of the
  1590                                                   instance object for more information)
  1591                                               loc : array_like, optional
  1592                                                   location parameter (default=0)
  1593                                               scale : array_like, optional
  1594                                                   scale parameter (default=1)
  1595
  1596                                               Returns
  1597                                               -------
  1598                                               logpdf : array_like
  1599                                                   Log of the probability density function evaluated at x
  1600
  1601                                               """
  1602         1           14     14.0      0.0      args, loc, scale = self._parse_args(*args, **kwds)
  1603         1           23     23.0      0.0      x, loc, scale = map(asarray, (x, loc, scale))
  1604         1            2      2.0      0.0      args = tuple(map(asarray, args))
  1605         1        13706  13706.0      9.2      x = asarray((x-loc)*1.0/scale)
  1606         1           33     33.0      0.0      cond0 = self._argcheck(*args) & (scale > 0)
  1607         1         5331   5331.0      3.6      cond1 = (scale > 0) & (x >= self.a) & (x <= self.b)
  1608         1         5625   5625.0      3.8      cond = cond0 & cond1
  1609         1           84     84.0      0.1      output = empty(shape(cond), 'd')
  1610         1         6029   6029.0      4.1      output.fill(NINF)
  1611         1        11459  11459.0      7.7      putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
  1612         1         1093   1093.0      0.7      if any(cond):
  1613         1        58499  58499.0     39.4          goodargs = argsreduce(cond, *((x,)+args+(scale,)))
  1614         1            6      6.0      0.0          scale, goodargs = goodargs[-1], goodargs[:-1]
  1615         1        46401  46401.0     31.3          place(output, cond, self._logpdf(*goodargs) - log(scale))
  1616         1            4      4.0      0.0      if output.ndim == 0:
  1617                                                   return output[()]
  1618         1            1      1.0      0.0      return output

It looks like a not-insignificant amount of time is being spent checking and removing invalid arguments from the function input. If you can be sure you will never need to use that feature, just write your own function to calculate the logpdf.

Plus, since you are going to be multiplying probabilities (i.e. adding log-probabilities), you can use algebra to factor the terms that are common to every point out of the sum of normal log-densities. That reduces the number of calls to functions like np.log. I did this in a hurry, so I may have made a math mistake, but:

def my_logpdf_sum(x, loc, scale):
    root2 = np.sqrt(2)
    root2pi = np.sqrt(2*np.pi)
    # constant term, factored out of the sum: -N * log(scale * sqrt(2*pi))
    prefactor = -x.size * np.log(scale * root2pi)
    # data-dependent term: -(x - loc)^2 / (2 * scale^2), summed over all points
    summand = -np.square((x - loc)/(root2 * scale))
    return prefactor + summand.sum()
# toy data
y = np.random.uniform(low=-1, high=1, size=1000) # data points
loc = np.zeros(y.shape)

# timing
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit np.sum(norm.logpdf(y, loc=loc, scale=1))
%timeit my_logpdf_sum(y, loc, 1)
1 loops, best of 3: 156 ms per loop
10000 loops, best of 3: 125 μs per loop
The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 16.3 μs per loop
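
For reference, the algebra behind my_logpdf_sum (using mu_i for the per-point means and sigma for the common scale; notation mine) is just the normal log-density with the constant term pulled out of the sum:

$$\sum_{i=1}^{N}\left[-\log\left(\sigma\sqrt{2\pi}\right) - \frac{(x_i-\mu_i)^2}{2\sigma^2}\right] = \underbrace{-N\log\left(\sigma\sqrt{2\pi}\right)}_{\text{prefactor}} \;\; \underbrace{-\sum_{i=1}^{N}\left(\frac{x_i-\mu_i}{\sqrt{2}\,\sigma}\right)^{2}}_{\text{summand.sum()}}$$

The left-hand side is the same value that np.sum(norm.logpdf(...)) computes term by term; the right-hand side is what prefactor + summand.sum() evaluates.
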
answered Aug 3, 2015 at 13:04
