I am playing with the CIFAR-10 dataset (available here) and, for now, I would like to plot one of the images in a batch.
When I load them with pickle, the images are represented as flat vectors of 3072 (3 x 1024) values.
From the cifar-10 documentation:
The first 1024 entries (of an image) contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.
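Put differently, the value for channel c (0 = red, 1 = green, 2 = blue) at row r, column x of a 32 x 32 image sits at offset c*1024 + r*32 + x in the 3072-entry vector. Below is a small sketch of that arithmetic. The helper flat_index is hypothetical (not part of any CIFAR-10 tooling), and the check uses a synthetic vector rather than the real data:

```python
import numpy as np

def flat_index(channel, row, col, size=32):
    # hypothetical helper: offset of (channel, row, col) in the flat CIFAR-10 vector;
    # each channel plane holds size*size values stored in row-major order
    return channel * size * size + row * size + col

# synthetic "image": entry i simply holds the value i
flat = np.arange(3 * 32 * 32)

# viewing the vector as (channel, row, col) gives the same lookup as the formula
planes = flat.reshape(3, 32, 32)
assert planes[1, 1, 0] == flat[flat_index(1, 1, 0)]  # green, second row, first column
print(flat_index(1, 1, 0))  # 1056
```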
What I came up with to plot an image is this:
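import pickle
import numpy as np
import matplotlib.pyplot as plt
def unpickle(file):
    # loader adapted from the CIFAR-10 page; dictionary keys are bytes, e.g. b'data'
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
# get the dataset
a = unpickle('./cifar-10/data_batch_1')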
# get the first image
img = np.array(a[b'data'][0])
# transform it to a 3 x 1024 array, one row per color channel
# and transpose it to a 1024 x 3 array, one row per rgb pixel
img = img.reshape(3, 1024).T
# reshape it so we can plot it as a 32 x 32 image with 3 color channels
img = img.reshape(32, 32, 3)
# plot
plt.imshow(img)
plt.show()
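Since the dataset file may not be at hand, here is a sketch that runs the same reshape/transpose steps on a synthetic uint8 vector (an assumption standing in for a[b'data'][0]), prints the shape after each step, and checks that a pixel lands where the documentation says it should:

```python
import numpy as np

# synthetic stand-in for one CIFAR-10 row: 3072 uint8 values
rng = np.random.default_rng(0)
flat = rng.integers(0, 256, size=3072, dtype=np.uint8)

step1 = flat.reshape(3, 1024)   # one row per color channel
print(step1.shape)              # (3, 1024)
step2 = step1.T                 # one row per pixel: (R, G, B)
print(step2.shape)              # (1024, 3)
img = step2.reshape(32, 32, 3)  # height x width x channels
print(img.shape)                # (32, 32, 3)

# the pixel at row 5, column 7 should hold (R, G, B) taken from the three channel planes
r, c = 5, 7
expected = [flat[k * 1024 + r * 32 + c] for k in range(3)]
assert list(img[r, c]) == expected
```

plt.imshow treats a uint8 array of shape (M, N, 3) as RGB values in 0..255, which is why the channel axis has to end up last.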
It's my first attempt at matrix manipulation, so even though this is concise, I feel like it could be simpler. What do you guys think?
Answer
One alternative is to reshape the vector straight to (3, 32, 32), then use moveaxis to put the channel axis last. I don't know how much simpler this is than what you've got; I guess it avoids one reshaping operation.
img = img.reshape(3, 32, 32)
img = np.moveaxis(img, 0, -1) # move the first axis to the end
or as a one-liner:
img = np.moveaxis(img.reshape(3, 32, 32), 0, -1)
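The same idea extends to a whole batch. The CIFAR-10 documentation describes a[b'data'] as a 10000 x 3072 array of uint8s; assuming that shape, a sketch that converts every image at once (using a small synthetic batch here) could look like this:

```python
import numpy as np

# synthetic batch standing in for a[b'data']: n rows of 3072 values
n = 4
batch = np.arange(n * 3072).reshape(n, 3072)

# (n, 3072) -> (n, 3, 32, 32) -> (n, 32, 32, 3): move the channel axis to the end
images = np.moveaxis(batch.reshape(-1, 3, 32, 32), 1, -1)
print(images.shape)  # (4, 32, 32, 3)

# each image matches the single-image one-liner
for i in range(n):
    single = np.moveaxis(batch[i].reshape(3, 32, 32), 0, -1)
    assert np.array_equal(images[i], single)
```

With the real uint8 data, images[i] can then be passed directly to plt.imshow.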
Note that moveaxis returns a view, meaning that no data is copied.
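One way to confirm this is np.shares_memory. The sketch below uses synthetic zeros in place of real pixel data; it also shows that only the strides change, and that writing through the view modifies the original vector, so call .copy() if an independent array is needed:

```python
import numpy as np

flat = np.zeros(3072, dtype=np.uint8)
img = np.moveaxis(flat.reshape(3, 32, 32), 0, -1)

print(np.shares_memory(img, flat))  # True: same buffer, only shape/strides differ
print(img.strides)                  # (32, 1, 1024)

# writing through the view changes the original vector
img[0, 0, 2] = 255                  # blue channel of the top-left pixel
print(flat[2048])                   # 255
```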