I am experimenting with making my own re-usable libraries. Therefore, I decided to start with some of the plots that I generally use during development/debugging to check what is actually inside my data.
My questions therefore are:
- Would this be the correct way of writing a library?
- Are there any pieces of my code that could be improved and/or be made simpler?
The library (it needs a new title, as the old one is already taken):
#! /usr/bin/env python
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import statsmodels.api as sm

def linear(x,a,b):
    return a*x+b

def quadratic(x,a,b,c):
    return a*x**2+b*x+c

def power_law(x,a,b,c):
    return a*x**b+c

def scatterplot_fit(X,Y,**kwargs):
    """
    Takes the X and Y lists and plots them as a 2D scatter plot
    through matplotlib. Additionally, the least squares fit is
    plotted throughout the datapoints.

    Keyword arguments:
    X -- List of the X-coordinates
    Y -- List of the Y-coordinates
    function -- Function to be used for curve fitting (default 'linear')
                Alternatives: 'quadratic','lowess' and 'power_law'
    xlabel -- Label for the X-axis (default "")
    ylabel -- Label for the Y-axis (default "")
    title -- Title for the plot (default "")
    """
    function, xlabel, ylabel, title = kwargs.get('function','linear'), kwargs.get('xlabel',""), kwargs.get('ylabel',""), kwargs.get('title',"")
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    ax = fig.add_subplot(111)
    s = ax.scatter(X,Y)
    newX = np.linspace(min(X), max(X), 1000)
    if function == 'linear':
        popt, pcov = curve_fit(linear, X, Y)
        newY = linear(newX,*popt)
        a,b = popt
        label = "{:.2f}".format(a)+"*x+"+"{:.2f}".format(b)
    elif function == 'quadratic':
        popt, pcov = curve_fit(quadratic, X, Y)
        newY = quadratic(newX,*popt)
        a,b,c = popt
        label = "{:.2f}".format(a)+"*x**2+"+"{:.2f}".format(b)+"b*x+"+"{:.2f}".format(c)
    elif function == 'lowess':
        lowess = sm.nonparametric.lowess(Y, X)
        newX,newY = lowess[:, 0], lowess[:, 1]
        label='Lowess Fit'
    elif function == 'power_law':
        popt, pcov = curve_fit(power_law, X, Y)
        newY = power_law(newX,*popt)
        a,b,c = popt
        label = "{:.2f}".format(a)+"*x**"+"{:.2f}".format(b)+"+"+"{:.2f}".format(c)
    else:
        print "Incorrect function specified, please use linear, quadratic, lowess or power_law"
        return None
    plt.plot(newX,newY,label=label)
    ax.grid(True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, ncol=2, mode="expand", borderaxespad=0.)
    plt.show()
    plt.close()

def heatmap_scatterplot(X,Y,Z,**kwargs):
    """
    Takes the X and Y lists and plots them as a scatterplot
    through matplotlib, with color coding of the points based
    on the Z list.

    Keyword arguments:
    X -- List of the X-coordinates
    Y -- List of the Y-coordinates
    Z -- List of the Z-coordinates
    vmin -- Minimum value to be displayed in the colorbar (default min(Z))
    vmax -- Maximum value to be displayed in the colorbar (default max(Z))
    edges -- The edges of each individual datapoint (default 'black')
    cm -- The colormap used for the colorbar (default 'jet')
    xlabel -- Label for the X-axis (default "")
    ylabel -- Label for the Y-axis (default "")
    zlabel -- Label for the Z-axis (default "")
    title -- Title for the plot (default "")
    """
    vmin, vmax, edges, cm, xlabel, ylabel, zlabel, title = kwargs.get('vmin',min(Z)), kwargs.get('vmax',max(Z)), kwargs.get('edges','black'), kwargs.get('cm','jet'), kwargs.get('xlabel',""), kwargs.get('ylabel',""), kwargs.get('zlabel',""), kwargs.get('title',"")
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    ax = fig.add_subplot(111)
    s = ax.scatter(X,Y,c=Z,edgecolor=edges)
    ax.grid(True)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    ax1 = fig.add_axes([0.95, 0.1, 0.01, 0.8])
    cb = mpl.colorbar.ColorbarBase(ax1,norm=norm,cmap=cm,orientation='vertical')
    cb.set_clim(vmin=min(Z), vmax=max(Z))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    cb.set_label(zlabel)
    ax.set_title(title)
    plt.show()
    plt.close()

def three_dimension_scatterplot(X,Y,Z,**kwargs):
    """
    Takes the X, Y and Z lists and plots them as a 3D scatter plot
    through matplotlib.

    Keyword arguments:
    X -- List of the X-coordinates
    Y -- List of the Y-coordinates
    Z -- List of the Z-coordinates
    xlabel -- Label for the X-axis (default "")
    ylabel -- Label for the Y-axis (default "")
    zlabel -- Label for the Z-axis (default "")
    title -- Title for the plot (default "")
    """
    xlabel, ylabel, zlabel, title = kwargs.get('xlabel',""), kwargs.get('ylabel',""), kwargs.get('zlabel',""), kwargs.get('title',"")
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    ax = fig.add_subplot(111, projection='3d')
    s = ax.scatter(X,Y,Z)
    ax.grid(True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_title(title)
    plt.show()
    plt.close()

def wireframe(X,Y,Z,**kwargs):
    """
    Takes the X, Y and Z lists and plots them as a 3D wireframe
    through matplotlib.

    Keyword arguments:
    X -- List of the X-coordinates
    Y -- List of the Y-coordinates
    Z -- List of the Z-coordinates
    xlabel -- Label for the X-axis (default "")
    ylabel -- Label for the Y-axis (default "")
    zlabel -- Label for the Z-axis (default "")
    title -- Title for the plot (default "")
    """
    xlabel, ylabel, zlabel, title = kwargs.get('xlabel',""), kwargs.get('ylabel',""), kwargs.get('zlabel',""), kwargs.get('title',"")
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_wireframe(X,Y,Z)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_title(title)
    plt.show()
    plt.close()

def surface(X,Y,Z,**kwargs):
    """
    Takes the X, Y and Z lists and plots them as a 3D surface plot
    through matplotlib.

    Keyword arguments:
    X -- List of the X-coordinates
    Y -- List of the Y-coordinates
    Z -- List of the Z-coordinates
    xlabel -- Label for the X-axis (default "")
    ylabel -- Label for the Y-axis (default "")
    zlabel -- Label for the Z-axis (default "")
    title -- Title for the plot (default "")
    """
    xlabel, ylabel, zlabel, title = kwargs.get('xlabel',""), kwargs.get('ylabel',""), kwargs.get('zlabel',""), kwargs.get('title',"")
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X,Y,Z)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_title(title)
    plt.show()
    plt.close()

1 Answer
Overall the code is pretty good, and close to how I would do a library. I do have some suggestions, however:
- Follow the PEP 8 style guide.
- When writing numerical code, always include `from __future__ import division`, so that `/` performs true division even when both operands are integers.
- Don't put multiple commands on one line. So, for example, for your kwargs, each should be on a separate line.
- Rather than using `**kwargs` and getting the values out of the dict, you should define specific keyword arguments and give them default values. In cases where the default value needs to be computed at runtime (such as `max(Z)`), you can set the default as `None`, then test if it is `None` at the beginning of the function. This is both simpler for you and much, much easier for people wanting to use your library.
- I would have `*args` and `**kwargs` in every function, which are passed unchanged to the matplotlib plotting function. This is a good way to keep your code simple while still allowing access to the more advanced capabilities of the library you are using.
- Your plot setup code and plot formatting code are pretty consistent across functions. You can split those out into their own functions to reduce code duplication.
- I would split the fitting code (currently in the `if...elif` section) into their own functions, then access them using a `dict`.
- If I were writing a library, I would split the fitting bits (which don't require matplotlib) into their own python file, and only keep the plotting-specific bits in this file.
- Also, if I was writing a library, I would have an optional `ax` argument for each plotting function that lets you pass an axes object. If that happens, then the figure creation, `plt.show()`, and `plt.close()` parts aren't called. This allows you to use these functions with subplots or make additional formatting changes before showing it, or just save the figure to a file without showing it at all.
- If I was writing a library, I would also make the face color a keyword argument with `white` being the default value.
- I would probably abstract the scatterplot bits of `scatterplot_fit` and `heatmap_scatterplot` into a `scatterplot` function. With the ability mentioned above to pass an axes object to the plotting functions, your `scatterplot_fit` and `heatmap_scatterplot` would be able to create an axes, pass it to the `scatterplot` function for plotting the scatterplot, then do their additional stuff with the axes afterwards.
- This code won't do anything when run as a script so it doesn't need a shebang.
- I would document the first three functions.
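
To illustrate the point about explicit keyword arguments with `None` defaults, here is a minimal sketch. The helper name `resolve_color_limits` is hypothetical and not part of the original library; it only shows how runtime-dependent defaults such as `min(Z)` and `max(Z)` could be filled in at the top of a function.

```python
def resolve_color_limits(Z, vmin=None, vmax=None):
    """Return (vmin, vmax), falling back to the range of Z when unset."""
    # Compare against None explicitly so that a caller-supplied 0 is kept.
    if vmin is None:
        vmin = min(Z)
    if vmax is None:
        vmax = max(Z)
    return vmin, vmax


# Illustrative data only.
print(resolve_color_limits([3, 1, 4, 1, 5]))          # (1, 5)
print(resolve_color_limits([3, 1, 4, 1, 5], vmin=0))  # (0, 5)
```

The same pattern would apply directly to the `vmin` and `vmax` arguments of `heatmap_scatterplot`, and the signature then documents every accepted option by itself.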
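
The suggestion to replace the `if...elif` chain with a `dict` could look roughly like the sketch below. The names `FIT_MODELS`, `get_fit_model` and `fit_label` are assumptions made for illustration; the model functions are the ones from the question. The fitting call itself is left out so the sketch stays free of dependencies.

```python
def linear(x, a, b):
    return a * x + b


def quadratic(x, a, b, c):
    return a * x**2 + b * x + c


def power_law(x, a, b, c):
    return a * x**b + c


# Hypothetical registry: name -> (model function, legend label template).
FIT_MODELS = {
    "linear": (linear, "{:.2f}*x+{:.2f}"),
    "quadratic": (quadratic, "{:.2f}*x**2+{:.2f}*x+{:.2f}"),
    "power_law": (power_law, "{:.2f}*x**{:.2f}+{:.2f}"),
}


def get_fit_model(name):
    """Look up a model by name, raising a clear error for unknown names."""
    try:
        return FIT_MODELS[name]
    except KeyError:
        raise ValueError(
            "Unknown function {!r}; choose one of {}".format(name, sorted(FIT_MODELS))
        )


def fit_label(name, params):
    """Build the legend label for a set of fitted parameters."""
    _, template = get_fit_model(name)
    return template.format(*params)


model, _ = get_fit_model("quadratic")
print(model(2.0, 1.0, 2.0, 3.0))          # 1*4 + 2*2 + 3 = 11.0
print(fit_label("linear", (1.5, -0.25)))  # 1.50*x+-0.25
```

Inside `scatterplot_fit`, the looked-up `model` would then be passed to `curve_fit(model, X, Y)` once, instead of once per branch. The `lowess` option is not a parametric model, so it would probably still need its own code path.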
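
The optional `ax` argument, the face color keyword and the pass-through `**kwargs` can be combined in one shared function. The following is only a sketch of how such a `scatterplot` function might be written, not a definitive design; the parameter names other than `ax` are assumptions.

```python
import matplotlib.pyplot as plt


def scatterplot(X, Y, ax=None, facecolor="white", xlabel="", ylabel="",
                title="", **kwargs):
    """Scatter X against Y on `ax`, creating a figure only if none is given."""
    own_figure = ax is None
    if own_figure:
        fig = plt.figure()
        fig.patch.set_facecolor(facecolor)
        ax = fig.add_subplot(111)
    # Extra keyword arguments are passed unchanged to Axes.scatter.
    ax.scatter(X, Y, **kwargs)
    ax.grid(True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if own_figure:
        # Only show and close figures that this function created itself.
        plt.show()
        plt.close(fig)
    return ax
```

A caller who wants subplots could create the axes with `plt.subplots(1, 2)`, pass each one in through `ax=...`, and then call `fig.savefig(...)` without anything ever being shown. Called without `ax`, the function behaves like the original plotting functions.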