I'm trying to plot the first 5 images from a dataset.
The function to do so is:
def plotImages(image_arr):
fig, axes = plt.subplot(1, 5, figsize=(20,20))
axes = axes.flatten()
for img, ax in zip(image_arr, axes):
ax.imshow(img)
plt.tight_layout()
plt.show()
But when I call the function with the following:
plotImages(sample_training_images[:5])
..it throws me tthe error:
ValueError: Illegal argument(s) to subplot: (1, 5)
Here is the full code before the fact:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import matplotlib.pyplot as plt
import numpy as np
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
URL = r'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = tf.keras.utils.get_file('cats_and_dogs_filtered.zip', origin=URL, extract=True)
zip_dir_base = os.path.dirname(zip_dir)
base_dir = os.path.join(os.path.dirname(zip_dir), 'cats_and_dogs_filtered')
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr+num_dogs_tr
total_validation = num_cats_val+num_dogs_val
print(total_train)
print(total_validation)
BATCH_SIZE = 100
IMG_SHAPE = 150
train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=train_dir, shuffle=True, target_size=(IMG_SHAPE,IMG_SHAPE), class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=validation_dir, shuffle=True, target_size=(IMG_SHAPE,IMG_SHAPE), class_mode='binary')
sample_training_images, _ = next(train_data_gen)
def plotImages(image_arr):
fig, axes = plt.subplot(1, 5, figsize=(20,20))
axes = axes.flatten()
for img, ax in zip(image_arr, axes):
ax.imshow(img)
plt.tight_layout()
plt.show()
plotImages(sample_training_images[:5])
1 Answer 1
from the docs of subplot:
Call signatures:
subplot(nrows, ncols, index, **kwargs)
subplot(pos, **kwargs)
subplot(ax)
So it looks like you aren't providing the index argument to the function.
And so, with subplot you'd have to take care of the indexing individually for each plot.
I believe, subplots
, notice the additional s at the end, is what you're looking for.
https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.pyplot.subplots.html
With that, you could simple do fig, axes = plt.subplots(1,5)