This function works exactly as I want, but it's taking too long.
To speed it up, I've tried to do as much as I can before the main for loop by binding each function to a local variable. I've also switched from pandas DataFrames to NumPy arrays and decreased the output dpi.
The function is fed large amounts of data, so any suggestions for speed improvements would be much appreciated. I don't know any Cython (or C), but I'd be willing to learn some if it would dramatically improve performance. I also welcome suggestions on how to improve the style of my code.
import os
import logging
import traceback
import warnings
from itertools import chain
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
NO_GROUPING_NAME = 'NoGrouping'
plt.style.use('ggplot')
def cdf_plot(total_param_values):
"""
Given a 3-deep nested dictionary, cdf_plot saves a cumulative frequency
distribution plot out of the values of each inner-most dictionary. This will
be a scatter plot with colours corresponding to the keys of the dict being plotted.
If a 2-deep nested dictionary has key == NO_GROUPING_NAME then the corresponding
value will be a dictionary of only one key with value of one list of floats, so
the plot will only have one colour. In this case, no legend is drawn.
The cumulative frequency distribution data is formed from a list of values
(call the list all_x_values) by plotting the sorted values on the x-axis.
A corresponding y-value (for a given x-value) is equal to norm.ppf(i/len(all_x_values))
where i is the index of the given x-value in all_x_values and norm.ppf is a
function from scipy.stats (Percent point function (inverse of cdf — percentiles)).
Parameters
----------
total_param_values : { string : { string : { string : list of floats}}}
This corresponds to {p_id : {grouping : {group_instance : values}}}
"""
# Do as much as possible before loop
fig = plt.figure()
add_subplot = fig.add_subplot
textremove = fig.texts.remove
xlabel = plt.xlabel
ylabel = plt.ylabel
yticks = plt.yticks
cla = plt.cla
savefig = plt.savefig
figtext = plt.figtext
currentfigtext = None
colours = ('b', 'g', 'r', 'c','teal', 'm','papayawhip', 'y', 'k',
'aliceblue', 'aqua', 'forestgreen', 'deeppink', 'blanchedalmond',
'burlywood', 'darkgoldenrod')
nparray = np.array
nanstd = np.nanstd
nanmean = np.nanmean
npsort = np.sort
isnan = np.isnan
vectorize = np.vectorize
normppf = norm.ppf
chainfrom_iterable = chain.from_iterable
# Prepare yticks
y_labels = [0.0001, 0.001, 0.01, 0.10, 0.25, 0.5,
0.75, 0.90, 0.99, 0.999, 0.9999]
y_pos = [normppf(i) for i in y_labels]
try:
# Hide annoying warning
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=FutureWarning)
for p_id, p_id_dict in total_param_values.items():
for grouping, grouping_dict in p_id_dict.items():
#check whether plot already exists
save_name = p_id + grouping + '.png'
if os.path.exists(save_name):
continue
# Keep count of position in colour cycle
colour_count = 0
ax = add_subplot(111)
axscatter = ax.scatter
# Work out normalising function
chn = chainfrom_iterable(grouping_dict.values())
flattened = list(chn)
std = nanstd(flattened)
mean = nanmean(flattened)
if std:
two_ops = lambda x: (x - mean) / std
v_norm = vectorize(two_ops)
else:
one_op = lambda x: x - mean
v_norm = vectorize(one_op)
# Keep track of total number of values plotted this iteration
total_length = 0
for group_instance, values in grouping_dict.items():
values = nparray(values)
values = npsort(values[~isnan(values)])
length = len(values)
total_length += length
# Skip graphing any empty array
if not length:
continue
# Normalise values to be ready for plotting on x-axis
values = v_norm(values)
# Prepare y-values as described in function doc
y = [normppf(i/length) for i in range(length)]
axscatter(values, y, color=colours[colour_count % len(colours)],
label=group_instance + ' (' + str(length) + ')',
alpha=0.6)
colour_count += 1
# If no values were found, clear axis and skip to next iteration
if not total_length:
cla()
continue
if grouping != NO_GROUPING_NAME:
try:
ax.legend(loc='lower right', title=grouping + ' ('
+ str(total_length) + ')', fontsize=6,
scatterpoints=1)
except ValueError:
print('EXCEPTION: Weird error with legend() plotting,\
something about taking the max of an empty sequence')
pass
else:
# Turn off legend but display total_length in bottom right corner
ax.legend().set_visible(False)
if currentfigtext is not None:
textremove(currentfigtext)
currentfigtext = figtext(0.99, 0.01, 'Number of points = '
+ str(total_length),
horizontalalignment='right')
xlabel('')
ylabel('')
yticks(y_pos, y_labels)
savefig(save_name, dpi=60)
cla()
except Exception as e:
print('It broke.............', e)
print('Variable dump........')
print('grouping {}, group_instance {}, values {}, length {}, ax {}, y {},\
colour_count {}, figtext {} std {}, mean {},\
save_name {}'.format(grouping, group_instance, values, length,
ax, y, colour_count, currentfigtext, std, mean, save_name))
logging.error(traceback.format_exc())
raise
# Make sure no figures remain
plt.close('all')
# In an attempt to make the nesting clear I've written test_values out in a
# weird way (weird to me at least)
test_values = {
'p_1' : {
'NoGrouping' : {
'' : list(np.random.rand(100))
},
'Sky' : {
'Blue' : list(np.random.rand(100)),
'Red' : list(np.random.rand(100)),
'Grey' : list(np.random.rand(100))
}
},
'p_2' : {
'NoGrouping' : {
'' : list(np.random.rand(100))
},
'Sky' : {
'Blue' : list(np.random.rand(100)),
'Red' : list(np.random.rand(100)),
'Grey' : list(np.random.rand(100))
}
}
}
cdf_plot(test_values)
Comment (Curt F., Jan 13, 2016): Have you run line profiling on your code?
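For reference, a minimal sketch of line profiling with the third-party line_profiler package (assuming it is installed, e.g. via pip install line_profiler), reusing cdf_plot and test_values from the question:

from line_profiler import LineProfiler

lp = LineProfiler()
profiled_cdf_plot = lp(cdf_plot)   # wrap the function to record per-line timings
profiled_cdf_plot(test_values)     # run it as usual
lp.print_stats()                   # print hit counts and time spent on each line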
1 Answer
Your function is too large to digest in one sitting, and I haven't run it on my machine, so I'll start with a few quick observations.
The idea of putting a plot function (axscatter) in an innermost loop just feels wrong. It makes more sense to build arrays of values and call the plot just once per subplot. But if the number of values you plot per call is large and the number of iterations is small, that isn't such a big deal. It's hard to know just how many iterations this code performs.
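A minimal sketch of that batching idea (the tiny grouping_dict, the accumulation lists, and the single scatter call are mine, purely illustrative, not the original code):

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Stand-in for one inner dictionary from the question's test_values.
grouping_dict = {'Blue': np.random.rand(100), 'Red': np.random.rand(100)}

fig = plt.figure()
ax = fig.add_subplot(111)

all_x, all_y, all_colours = [], [], []
for colour, (group_instance, values) in zip('bgrcmyk', grouping_dict.items()):
    values = np.sort(np.asarray(values, dtype=float))
    # Same y-values as the question: norm.ppf(i / length) for each index i.
    y = norm.ppf(np.arange(len(values)) / len(values))
    all_x.append(values)
    all_y.append(y)
    all_colours.extend([colour] * len(values))

# One scatter call per subplot instead of one call per group.
ax.scatter(np.concatenate(all_x), np.concatenate(all_y), c=all_colours, alpha=0.6)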
The use of vectorize on functions like lambda x: (x - mean) / std is unnecessary. vectorize does not speed up code; it just makes it easier to apply scalar operations to arrays. x - mean is already a NumPy array operation, as is the division by std.
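In other words, plain NumPy arithmetic already broadcasts over the whole array (a small sketch with made-up numbers):

import numpy as np

values = np.array([1.0, 2.0, np.nan, 4.0])
values = values[~np.isnan(values)]
mean, std = values.mean(), values.std()

# Elementwise subtraction and division work directly on the array;
# no np.vectorize wrapper is needed.
normalised = (values - mean) / std if std else values - mean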
The use of the list comprehension y = [normppf(i/length) for i in range(length)] is, I suspect, unnecessary. But I don't know exactly what normppf does. Does it accept an array, or just scalar values?
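Since scipy.stats.norm.ppf does accept array input (the comment below confirms this), the whole list comprehension can be replaced with a single vectorised call, roughly:

import numpy as np
from scipy.stats import norm

length = 100
# Equivalent to [norm.ppf(i / length) for i in range(length)], computed in one call.
y = norm.ppf(np.arange(length) / length)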
Organizing this as one big function makes it hard to test and optimize pieces. I'd put each level of iteration in a separate function. I should be able to test the data manipulation functions without plotting, and be able to plot without generating a new set of data.
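A rough outline of that split, with hypothetical function names and signatures of my own choosing:

import numpy as np
from itertools import chain
from scipy.stats import norm


def prepare_values(values, mean, std):
    # Data step only: drop NaNs, sort, normalise, and build the y-values.
    values = np.asarray(values, dtype=float)
    values = np.sort(values[~np.isnan(values)])
    x = (values - mean) / std if std else values - mean
    y = norm.ppf(np.arange(len(values)) / len(values))
    return x, y


def plot_grouping(ax, grouping_dict):
    # Plot step only: draw one subplot from an inner dictionary.
    flattened = list(chain.from_iterable(grouping_dict.values()))
    mean, std = np.nanmean(flattened), np.nanstd(flattened)
    for group_instance, values in grouping_dict.items():
        x, y = prepare_values(values, mean, std)
        ax.scatter(x, y, label='{} ({})'.format(group_instance, len(x)), alpha=0.6)

With a split like this, the data preparation can be tested without ever opening a figure.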
I like to see, at a glance, where the code is massaging the data and where it is plotting. And with plotting, it is nice to separate the actions that actually plot from those that tweak the appearance. I also like to be able to identify clearly when the code is calling imported code. You do that with np.... but not with other functions.
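For instance (purely illustrative, not a performance claim):

import numpy as np

flattened = [1.0, 2.0, 3.0, float('nan')]

# Aliased call: at the call site it is not obvious this is NumPy code.
nanstd = np.nanstd
std = nanstd(flattened)

# Module-qualified call: the np. prefix flags the imported code immediately.
std = np.nanstd(flattened)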
Comment (Ben Carr, Nov 30, 2015): Thanks, your suggestions are great and I will be using them all. normppf does indeed accept an array, so the list comprehension is redundant.