I quite often plot graphs looking at how some property or function varies with different parameters. Normally, I find the analysis code can be written fairly succinctly and separated into a suitable number of functions. However, when I come to plot the graphs, my code tends to become unmanageably long.
Below is a basic example. While the code for generating the graphic is modest by my standards, I hope it highlights the issue; there will often be several more `if` statements for adding different things to a specific axis.
Do you have any pointers for keeping plot generation succinct, or is it inevitable when building custom plots?
```python
import matplotlib.pyplot as plt
import numpy as np


def get_data(x, order, scale):

    '''Generates data for this example.'''

    y = scale * (x**order)

    return y


# generate the example data
orders = [2, 2.2, 2.3]
scales = [1, 2]
scale_grid, order_grid = np.meshgrid(scales, orders)
parameters = list(zip(order_grid.ravel(), scale_grid.ravel()))
my_x = np.arange(0, 10.1, 0.1)
my_ys = []
for ps in parameters:

    my_ys.append(get_data(my_x, *ps))

###############################################
# generate the graph
my_colors = {0:'r', 1:'b', 2:'g'}
fig, ax_rows = plt.subplots(3, ncols=2)
for i, ax_row in enumerate(ax_rows):
    # plot the graphs
    ax_row[0].plot(my_x, my_ys[i*2], lw=2, color=my_colors[i])
    ax_row[1].plot(my_x, my_ys[i*2+1], lw=2, color=my_colors[i])
    # format the axes
    ax_row[0].set_ylabel('y (unit)')
    plt.setp(ax_row[1].get_yticklabels(), visible=False)
    for ax in ax_row:
        ax.set_ylim(0, 500)
        ax.set_xlim(0, 10)
        if i!=2:
            plt.setp(ax.get_xticklabels(), visible=False)
        if i==2:
            ax.set_xlabel('x (unit)')
    # add the text displaying parameters
    ax_row[0].text(0.03, 0.95, 'Scale: {0:.2f}\nOrder: {1:.2f}'.format(parameters[i*2][1],
                                                                      parameters[i*2][0]),
                   transform=ax_row[0].transAxes,
                   verticalalignment='top')
    ax_row[1].text(0.03, 0.95, 'Scale: {0:.2f}\nOrder: {1:.2f}'.format(parameters[i*2+1][1],
                                                                      parameters[i*2+1][0]),
                   transform=ax_row[1].transAxes,
                   verticalalignment='top')
fig.set_size_inches(5, 5)
```
The output graphic is here: [Figure: example multi-panel graphic]
I'm not a matplotlib guru, but I'll show how I would approach the problem. Maybe you'll get something useful out of it. Also, I'm going to review all of the code, not only the plotting part.
`get_data`:

```python
def get_data(x, order, scale):

    '''Generates data for this example.'''

    y = scale * (x**order)

    return y
```
- For docstrings, use triple double quotes: `"""Your docstring"""`.
- There is no need to put so many blank lines inside such a small function.
- Exponentiation has higher precedence than multiplication, which means you can remove the parentheses: `scale * x ** order`.
- The variable `y` is redundant; just write `return scale * x ** order`.
- You can use type hints to help readers understand what types of data your function operates on. Some IDEs can also analyze them and will highlight places where what is supplied to a function is inconsistent with what that function expects to get.
That gives us:

```python
def get_data(x: np.ndarray,
             order: float,
             scale: float) -> np.ndarray:
    """Generates data for this example."""
    return scale * x ** order
```
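A quick check of the precedence point, as a sketch assuming NumPy is installed (`get_data_original` and `get_data_simplified` are hypothetical names used only for this comparison): since `**` is evaluated before `*`, dropping the parentheses doesn't change the result.

```python
import numpy as np


def get_data_original(x, order, scale):
    '''Version from the question, with explicit parentheses.'''
    y = scale * (x**order)
    return y


def get_data_simplified(x: np.ndarray, order: float, scale: float) -> np.ndarray:
    """Simplified version: ** binds tighter than *, so no parentheses needed."""
    return scale * x ** order


x = np.linspace(0, 10, 101)
# Both versions perform the same operations in the same order.
same = np.array_equal(get_data_original(x, 2.2, 2), get_data_simplified(x, 2.2, 2))
print(same)  # True

# The same rule with plain numbers: 2 * 3 ** 2 is 2 * 9, not 6 ** 2.
print(2 * 3 ** 2)  # 18
```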
Generating example data:

```python
# generate the example data
orders = [2, 2.2, 2.3]
scales = [1, 2]
scale_grid, order_grid = np.meshgrid(scales, orders)
parameters = list(zip(order_grid.ravel(), scale_grid.ravel()))
my_x = np.arange(0, 10.1, 0.1)
my_ys = []
for ps in parameters:

    my_ys.append(get_data(my_x, *ps))
```
- Obtaining `parameters` by using `np.meshgrid`, `np.ravel` and `zip` doesn't look good. `np.meshgrid` generates 2D arrays, which is unnecessary here. You can use `itertools.product` to get the Cartesian product of the input parameters: `list(itertools.product(orders, scales))`.
- The docs for `np.arange` warn: "When using a non-integer step, such as 0.1, the results will often not be consistent. It is better to use `numpy.linspace` for these cases." So, instead, you should have `my_x = np.linspace(0, 10, 101)`.
- There is no need to have a blank line after the first line of the for-loop.
- We don't need to keep all the data in `my_ys`. We can calculate it on the fly (see below).
That gives us:

```python
import itertools
from typing import Iterable, Iterator, Tuple

import numpy as np


def get_my_ys(x: np.ndarray,
              parameters: Iterable[Tuple[float, float]]
              ) -> Iterator[np.ndarray]:
    """Yields data for different parameters"""
    for order, scale in parameters:
        yield get_data(x, order=order, scale=scale)

...

orders = [2, 2.2, 2.3]
scales = [1, 2]
parameters = list(itertools.product(orders, scales))
my_x = np.linspace(0, 10, 101)
my_ys = get_my_ys(my_x, parameters)
```
You will probably want to extend it later to handle a variable number of parameters.
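A small sketch to check that these two suggestions don't change the data, using the same `orders` and `scales` as the question: `itertools.product` yields the same (order, scale) pairs, in the same order, as the `np.meshgrid`/`ravel`/`zip` combination, and `np.linspace(0, 10, 101)` gives exactly 101 points with both endpoints included. The names `via_meshgrid` and `via_product` are just for illustration.

```python
import itertools

import numpy as np

orders = [2, 2.2, 2.3]
scales = [1, 2]

# Approach from the question: build 2D grids, flatten them, then zip.
scale_grid, order_grid = np.meshgrid(scales, orders)
via_meshgrid = list(zip(order_grid.ravel(), scale_grid.ravel()))

# Suggested approach: Cartesian product, with orders varying slowest.
via_product = list(itertools.product(orders, scales))

# meshgrid yields NumPy scalars, so convert them before comparing.
print([(float(o), float(s)) for o, s in via_meshgrid] == via_product)  # True

# 101 evenly spaced points; both endpoints are included.
my_x = np.linspace(0, 10, 101)
print(len(my_x), my_x[0], my_x[-1])  # 101 0.0 10.0
```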
Generating the graph:

- First of all, why is `my_colors` a dict (`my_colors = {0:'r', 1:'b', 2:'g'}`)? When I see a dict whose keys are 0, 1, 2, ..., it makes me think that it should probably be a list instead.
- In `fig, ax_rows = plt.subplots(3, ncols=2)`, it looks inconsistent that you specify the keyword `ncols` but not `nrows`. Moreover, `3` and `2` are in fact the lengths of the parameter lists `orders` and `scales`, so you should tie them together. And, according to the docs, you could also specify `sharex` and `sharey` as `True`/`'all'`, so you wouldn't have to write the following for each subplot:

```python
for i, ax_row in enumerate(ax_rows):
    ...
    for ax in ax_row:
        ax.set_ylim(0, 500)
        ax.set_xlim(0, 10)
```
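A minimal sketch of the `sharex`/`sharey` point, assuming a recent matplotlib (the `Agg` backend is only there so it runs without a display): the grid shape comes from the parameter lists, the limits are set on a single subplot, and because all axes are shared, every other subplot picks them up.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

orders = [2, 2.2, 2.3]
scales = [1, 2]

# Tie the grid shape to the parameter lists instead of hard-coding 3 and 2.
fig, ax_rows = plt.subplots(nrows=len(orders), ncols=len(scales),
                            sharex='all', sharey='all')

# Set the limits once; the shared axes propagate them to every subplot.
ax_rows[0, 0].set_xlim(0, 10)
ax_rows[0, 0].set_ylim(0, 500)

x_shared = all(ax.get_xlim() == (0.0, 10.0) for ax in fig.axes)
y_shared = all(ax.get_ylim() == (0.0, 500.0) for ax in fig.axes)
print(x_shared, y_shared)  # True True
plt.close(fig)
```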
```python
# plot the graphs
ax_row[0].plot(my_x, my_ys[i*2], lw=2, color=my_colors[i])
ax_row[1].plot(my_x, my_ys[i*2+1], lw=2, color=my_colors[i])
```

There are several issues here. First of all, if you change the number of parameters in `scales` and `orders`, this won't work as intended. Next, there is code duplication across these two lines. And this indexing of `my_ys` just doesn't feel right. Ideally, this should look like:

```python
for ax, y, color in zip(fig.axes, my_ys, colors):
    ax.plot(x, y, lw=linewidth, color=color)
```
Note the `fig.axes`: it gives you a list of all the axes in the figure.

Again, this

```python
plt.setp(ax_row[1].get_yticklabels(), visible=False)
```

will remove the labels only in the second column. But what if you have more parameters and therefore more columns? Actually, we don't need these lines at all if you use `sharex` and `sharey` when creating the figure; it will take care of them automatically.

Instead of checking the indices of the subplots to add labels for the x-axis, I suggest simply iterating over the last row of `ax_rows` returned from `plt.subplots`:

```python
for ax in ax_rows[-1]:
    ax.set_xlabel(xlabel)
```
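To check that zipping over `fig.axes` pairs each subplot with the right parameters, here is a sketch assuming the same `orders`/`scales` and a recent matplotlib: `fig.axes` lists the subplots row by row, which is the same order in which `itertools.product(orders, scales)` yields its pairs. The `set_title` call is only a visible marker for this check and is not part of the suggested plot.

```python
import itertools

import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

orders = [2, 2.2, 2.3]
scales = [1, 2]
parameters = list(itertools.product(orders, scales))

fig, ax_rows = plt.subplots(nrows=len(orders), ncols=len(scales),
                            sharex='all', sharey='all')

# fig.axes is row-major, matching the order of itertools.product
# (orders vary slowest), so zip pairs each axis with its parameters.
row_major = list(ax_rows.ravel()) == fig.axes
for ax, (order, scale) in zip(fig.axes, parameters):
    ax.set_title(f'order={order}, scale={scale}', fontsize=6)

# Only the bottom row gets an x label.
for ax in ax_rows[-1]:
    ax.set_xlabel('x (unit)')

print(row_major)  # True
print(ax_rows[1, 0].get_title())  # order=2.2, scale=1
plt.close(fig)
```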
Though, we should be careful with the returned type of `ax_rows`, because, as the docs say, it "can be either a single Axes object or an array of Axes objects". To always get an array, we should specify `squeeze=False` in the `plt.subplots` call.

```python
ax_row[0].text(0.03, 0.95, 'Scale: {0:.2f}\nOrder: {1:.2f}'.format(parameters[i*2][1],
                                                                  parameters[i*2][0]),
               transform=ax_row[0].transAxes,
               verticalalignment='top')
ax_row[1].text(0.03, 0.95, 'Scale: {0:.2f}\nOrder: {1:.2f}'.format(parameters[i*2+1][1],
                                                                  parameters[i*2+1][0]),
               transform=ax_row[1].transAxes,
               verticalalignment='top')
```
Same problems here: code duplication, clumsy indexing, and it won't work if you add more input parameters to `orders` or `scales`. Here is how I would suggest generating the labels:

```python
label_template = 'Scale: {1:.2f}\nOrder: {0:.2f}'
labels = itertools.starmap(label_template.format, parameters)
```

Here I use `itertools.starmap` to supply the tuples of parameters to the `str.format` method of `label_template`.

In the end, plotting would look something like this:

```python
for ax, y, label, color in zip(fig.axes, my_ys, labels, colors):
    ax.plot(x, y, lw=linewidth, color=color)
    ax.text(s=label, transform=ax.transAxes, **text_properties)
```
where `text_properties` is a dict that holds all the remaining properties, like position, alignment, etc.
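Here is what the `starmap` call produces for the example parameters (a sketch; the list comprehension at the end is an equivalent alternative, not something from the original answer). Since each tuple is `(order, scale)`, the template uses `{1}` for the scale and `{0}` for the order.

```python
import itertools

orders = [2, 2.2, 2.3]
scales = [1, 2]
parameters = list(itertools.product(orders, scales))

# Each tuple is (order, scale): {0} is the order, {1} is the scale.
label_template = 'Scale: {1:.2f}\nOrder: {0:.2f}'
labels = list(itertools.starmap(label_template.format, parameters))

print(labels[0])
# Scale: 1.00
# Order: 2.00

# Equivalent list comprehension, if you find it easier to read.
same_labels = [label_template.format(*p) for p in parameters]
print(labels == same_labels)  # True
```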
Revised code:

```python
import itertools
from functools import partial
from typing import (Any,
                    Dict,
                    Iterable,
                    Iterator,
                    List,
                    Optional,
                    Tuple)

import matplotlib.pyplot as plt
import numpy as np

TEXT_PROPERTIES = dict(x=0.03,
                       y=0.95,
                       verticalalignment='top')


def main():
    my_colors = ['r', 'b', 'g']
    orders = [2, 2.2, 2.3]
    scales = [1, 2]
    parameters = list(itertools.product(orders, scales))
    my_x = np.linspace(0, 10, 101)
    my_ys = get_my_ys(my_x, parameters)
    label_template = 'Scale: {1:.2f}\nOrder: {0:.2f}'
    labels = itertools.starmap(label_template.format, parameters)
    colors = replicate_items(my_colors, times=len(scales))
    plot(x=my_x,
         ys=my_ys,
         nrows=len(orders),
         ncols=len(scales),
         labels=labels,
         colors=colors,
         xlim=[0, 10],
         ylim=[0, 500],
         xlabel='x (unit)',
         ylabel='y (unit)')
    plt.show()


def get_my_ys(x: np.ndarray,
              parameters: Iterable[Tuple[float, float]]
              ) -> Iterator[np.ndarray]:
    """Yields data for different parameters"""
    for order, scale in parameters:
        yield get_data(x, order=order, scale=scale)


def get_data(x: np.ndarray,
             order: float,
             scale: float) -> np.ndarray:
    """Generates data for this example."""
    return scale * x ** order


def replicate_items(seq: Iterable[Any],
                    times: int) -> Iterable[Any]:
    """replicate_items('ABC', 2) --> A A B B C C"""
    repeat = partial(itertools.repeat, times=times)
    repetitions = map(repeat, seq)
    yield from itertools.chain.from_iterable(repetitions)


def plot(x: np.ndarray,
         ys: Iterable[np.ndarray],
         nrows: int,
         ncols: int,
         labels: Iterable[str],
         colors: Iterable[str],
         xlim: List[float],
         ylim: List[float],
         xlabel: str,
         ylabel: str,
         text_properties: Optional[Dict[str, Any]] = None,
         linewidth: float = 2,
         fig_size: Tuple[float, float] = (5, 5)) -> plt.Figure:
    """TODO: add docstring"""
    if text_properties is None:
        text_properties = TEXT_PROPERTIES
    fig, ax_rows = plt.subplots(nrows=nrows,
                                ncols=ncols,
                                sharex='all',
                                sharey='all',
                                squeeze=False)
    fig.set_size_inches(fig_size)
    plt.xlim(xlim)
    plt.ylim(ylim)
    for ax, y, label, color in zip(fig.axes, ys, labels, colors):
        ax.plot(x, y, lw=linewidth, color=color)
        ax.text(s=label,
                transform=ax.transAxes,
                **text_properties)
    for ax in ax_rows[:, 0]:
        ax.set_ylabel(ylabel)
    for ax in ax_rows[-1]:
        ax.set_xlabel(xlabel)
    return fig


if __name__ == '__main__':
    main()
```
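A short sketch of `replicate_items` (copied from the revised code above) in action, which also shows a caveat worth knowing: `colors`, `labels` and `my_ys` in `main` are one-shot iterators. Once `plot` has consumed them they are empty, so if you want to reuse them (say, for a second figure), wrap them in `list(...)` first.

```python
import itertools
from functools import partial
from typing import Any, Iterable, Iterator


def replicate_items(seq: Iterable[Any], times: int) -> Iterator[Any]:
    """replicate_items('ABC', 2) --> A A B B C C"""
    repeat = partial(itertools.repeat, times=times)
    repetitions = map(repeat, seq)
    yield from itertools.chain.from_iterable(repetitions)


colors = replicate_items(['r', 'b', 'g'], times=2)
first_pass = list(colors)
second_pass = list(colors)  # the generator is already exhausted

print(first_pass)   # ['r', 'r', 'b', 'b', 'g', 'g']
print(second_pass)  # []
```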
I'm sure that there are other things that could be improved. But this should get you going.
Comment by FChm, Feb 13, 2019 at 11:41: Hi Georgy, thanks for your comprehensive answer! I shall review it when I have some time and accept it. There seems to be a lot of useful content here and a lot I can clearly do to improve! On one question you had: `my_colors` is a dictionary because usually I will have a key for each row (e.g., 'Small', 'Medium', 'Large').