I'd like to plot an animation of Lissajous curves using Python and matplotlib's animate
library. I really do not have a lot of experience with Python, so rather than performance increasements, I'm looking for best practices to improve (and/or shorten) my code.
The following code produces a .gif
file when evaluated as a jupyter-lab
cell:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
x_data = []
y_data = []
max_range = 1.2
f1 = 3 # sets the frequency for the horizontal motion
f2 = 5 # sets the frequency for the vertical motion
d1 = 0.0 # sets the phase shift for the horizontal motion
d2 = 0.5 # sets the phase shift for the vertical motion
delta1 = d1 * np.pi # I define the phase shift like this in order to use
delta2 = d2 * np.pi # ...d1 and d2 in the export file name
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(4,4), gridspec_kw={'width_ratios': [6, 1], 'height_ratios': [1, 6]})
for i in [ax1, ax2, ax3, ax4]:
i.set_yticklabels([]) # in order to remove the ticks and tick labels
i.set_xticklabels([])
i.set_xticks([])
i.set_yticks([])
i.set_xlim(-max_range, max_range)
i.set_ylim(-max_range, max_range)
ax2.set_visible(False)
line, = ax3.plot(0, 0) # line plot in the lower left
line2 = ax3.scatter(0, 0) # moving dot in the lower left
linex = ax1.scatter(0, 0) # moving dot on top
liney = ax4.scatter(0, 0) # moving dot on the right
def animation_frame(i):
ax1.clear() # I tried to put this in a loop like this:
ax1.set_yticklabels([]) # for i in [ax1,ax3,ax4]:
ax1.set_xticklabels([]) # i.clear()
ax1.set_xticks([]) # (... etc.)
ax1.set_yticks([]) # but this didn't work
ax3.clear()
ax3.set_yticklabels([])
ax3.set_xticklabels([])
ax3.set_xticks([])
ax3.set_yticks([])
ax4.clear()
ax4.set_yticklabels([])
ax4.set_xticklabels([])
ax4.set_xticks([])
ax4.set_yticks([])
ax1.set_xlim(-max_range, max_range) # after ax.clear() I apparently have to re-set these
ax3.set_xlim(-max_range, max_range)
ax3.set_ylim(-max_range, max_range)
ax4.set_ylim(-max_range, max_range)
x_data.append(np.sin(i * f1 + delta1)) # for the line plot
y_data.append(np.sin(i * f2 + delta2))
x_inst = np.sin(i * f1 + delta1) # for the scatter plot
y_inst = np.sin(i * f2 + delta2)
line, = ax3.plot(x_data, y_data)
line2 = ax3.scatter(x_inst, y_inst)
linex = ax1.scatter(x_inst, 0)
liney = ax4.scatter(0, y_inst)
fig.canvas.draw()
transFigure = fig.transFigure.inverted() # in order to draw over 2 subplots
coord1 = transFigure.transform(ax1.transData.transform([x_inst, 0]))
coord2 = transFigure.transform(ax3.transData.transform([x_inst, y_inst]))
my_line1 = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
coord1 = transFigure.transform(ax3.transData.transform([x_inst, y_inst]))
coord2 = transFigure.transform(ax4.transData.transform([0, y_inst]))
my_line2 = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
fig.lines = my_line1, my_line2, # moving vertical and horizontal lines
return line, line2, linex, liney
animation = FuncAnimation(fig, func=animation_frame, frames=np.linspace(0, 4*np.pi, num=800, endpoint=True), interval=1000)
animation.save('lissajous_{0}_{1}_{2:.2g}_{3:.2g}.gif'.format(f1,f2,d1,d2), writer='pillow', fps=50, dpi=200)
# This takes quite long, but since I'd like to have a smooth, slow animation,
# ...I'm willing to accept a longer execution time.
1 Answer 1
For demonstration purposes I disregard the .gif generation, and especially disregard Jupyter. Notebooks tend to produce a scope swamp. System parameters? Throw them in the swamp. Loops, matplotlib artists, initialization code? In the swamp.
plt.subplots()
is a convenience wrapper, and not a helpful one here. It produces a useless ax2
that you then need to hide. Instead, just use a GridSpec
so that you can pick the locations you want while still controlling ratios.
animation_frame
recreates your artists from scratch every time! All of those calls to plot()
and scatter()
need to be removed, and instead you need to perform data updates.
You're missing my_line1
, my_line2
from your tuple of returned updated artists.
Your axis and artist names are very difficult to understand; those need better names.
A one-second (1000 ms) update rate is quite slow; for the interactive plot let's speed it up a lot.
All together,
import typing
import numpy as np
import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt
class Lissajous(typing.NamedTuple):
fig: plt.Figure
ax_top: plt.Axes
ax_mid: plt.Axes
ax_right: plt.Axes
mid_curve: matplotlib.lines.Line2D
mid_scatter: matplotlib.collections.PathCollection
x_scatter: matplotlib.collections.PathCollection
y_scatter: matplotlib.collections.PathCollection
vert_line: matplotlib.lines.Line2D
horz_line: matplotlib.lines.Line2D
max_range: float
f1: float # frequency for the horizontal motion
f2: float # frequency for the vertical motion
d1: float # phase shift for the horizontal motion
d2: float # phase shift for the vertical motion
delta1: float # horizontal phase (radians)
delta2: float # vertical phase (radians)
x_data: list[float] = []
y_data: list[float] = []
@classmethod
def new(
cls,
max_range: float = 1.2,
f1: float = 3.,
f2: float = 5.,
d1: float = 0.,
d2: float = 0.5,
) -> typing.Self:
fig = plt.figure()
grid = matplotlib.gridspec.GridSpec(figure=fig, nrows=2, ncols=2, width_ratios=(6, 1), height_ratios=(1, 6))
ax_top = fig.add_subplot(grid[0,0])
ax_mid = fig.add_subplot(grid[1,0])
ax_right = fig.add_subplot(grid[1,1])
for ax in (ax_top, ax_mid, ax_right):
ax.set_yticklabels(())
ax.set_xticklabels(())
ax.set_xticks(())
ax.set_yticks(())
ax.set_xlim(-max_range, max_range)
ax.set_ylim(-max_range, max_range)
vert_line = matplotlib.lines.Line2D((),(), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
horz_line = matplotlib.lines.Line2D((),(), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
fig.add_artist(vert_line)
fig.add_artist(horz_line)
return cls(
fig=fig, ax_top=ax_top, ax_mid=ax_mid, ax_right=ax_right,
mid_curve=ax_mid.plot((), ())[0],
mid_scatter=ax_mid.scatter((), ()),
x_scatter=ax_top.scatter((), ()),
y_scatter=ax_right.scatter((), ()),
vert_line=vert_line, horz_line=horz_line, max_range=max_range,
f1=f1, f2=f2, d1=d1, d2=d2, delta1=d1*np.pi, delta2=d2*np.pi,
)
def make_animation(self, n_frames: int = 800, interval: int = 1_000) -> matplotlib.animation.FuncAnimation:
return matplotlib.animation.FuncAnimation(
fig=self.fig, func=self.update, interval=interval,
frames=np.linspace(start=0, stop=4*np.pi, num=n_frames, endpoint=True),
)
def update(self, t: float) -> tuple[matplotlib.artist.Artist, ...]:
x_inst = np.sin(t*self.f1 + self.delta1)
y_inst = np.sin(t*self.f2 + self.delta2)
self.x_data.append(x_inst)
self.y_data.append(y_inst)
self.mid_curve.set_xdata(self.x_data)
self.mid_curve.set_ydata(self.y_data)
self.mid_scatter.set_offsets((x_inst, y_inst))
self.x_scatter.set_offsets((x_inst, 0))
self.y_scatter.set_offsets((0, y_inst))
# This cannot be cached; it becomes invalid on resize
trans_figure = self.fig.transFigure.inverted()
ax, ay = trans_figure.transform(self.ax_top.transData.transform((x_inst, 0)))
bx, by = trans_figure.transform(self.ax_mid.transData.transform((x_inst, y_inst)))
self.vert_line.set_xdata((ax, bx))
self.vert_line.set_ydata((ay, by))
ax, ay = bx, by
bx, by = trans_figure.transform(self.ax_right.transData.transform((0, y_inst)))
self.horz_line.set_xdata((ax, bx))
self.horz_line.set_ydata((ay, by))
return self.mid_curve, self.mid_scatter, self.x_scatter, self.y_scatter, self.vert_line, self.horz_line
def main() -> None:
lissajous = Lissajous.new()
anim = lissajous.make_animation(interval=30)
plt.show()
if __name__ == '__main__':
main()
Explore related questions
See similar questions with these tags.