
I'm trying to plot the first 5 images from a dataset.

The function to do so is:

```python
def plotImages(image_arr):
    fig, axes = plt.subplot(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(image_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()
```

But when I call the function with the following:

```python
plotImages(sample_training_images[:5])
```

...it throws the error:

```
ValueError: Illegal argument(s) to subplot: (1, 5)
```

Here is the full code leading up to the call:

```python
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import matplotlib.pyplot as plt
import numpy as np
import logging

logger = tf.get_logger()
logger.setLevel(logging.ERROR)

URL = r'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = tf.keras.utils.get_file('cats_and_dogs_filtered.zip', origin=URL, extract=True)
zip_dir_base = os.path.dirname(zip_dir)
base_dir = os.path.join(os.path.dirname(zip_dir), 'cats_and_dogs_filtered')
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_validation = num_cats_val + num_dogs_val
print(total_train)
print(total_validation)

BATCH_SIZE = 100
IMG_SHAPE = 150
train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=BATCH_SIZE, directory=train_dir, shuffle=True,
    target_size=(IMG_SHAPE, IMG_SHAPE), class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=BATCH_SIZE, directory=validation_dir, shuffle=True,
    target_size=(IMG_SHAPE, IMG_SHAPE), class_mode='binary')
sample_training_images, _ = next(train_data_gen)

def plotImages(image_arr):
    fig, axes = plt.subplot(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(image_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()

plotImages(sample_training_images[:5])
```
Zephyr
asked Jul 10, 2020 at 5:50

1 Answer


From the docs of `subplot`:

Call signatures:

```python
subplot(nrows, ncols, index, **kwargs)
subplot(pos, **kwargs)
subplot(ax)
```

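So it looks like you aren't providing the `index` argument to the function.
With `subplot`, you would have to handle the indexing yourself, creating each plot with its own call. Note also that `subplot` returns a single Axes, not a `(fig, axes)` tuple, so the unpacking in your code would fail as well.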
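To illustrate the `subplot` route, here is a minimal sketch that calls `plt.subplot(1, 5, index)` once per image. Assumptions: the helper name `plot_images_with_subplot` and the random arrays standing in for `sample_training_images` are hypothetical, the `Agg` backend is only there so it runs without a display, and the function returns the Axes list purely so the result can be inspected.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; drop to get a window
import matplotlib.pyplot as plt
import numpy as np


def plot_images_with_subplot(image_arr):
    """Plot up to five images in one row using plt.subplot with an explicit index."""
    plt.figure(figsize=(20, 20))  # figsize is a figure property, so set it here, not in subplot()
    axes = []
    for i, img in enumerate(image_arr[:5]):
        # subplot(nrows, ncols, index): index is 1-based and runs left to right
        ax = plt.subplot(1, 5, i + 1)
        ax.imshow(img)
        axes.append(ax)
    plt.tight_layout()
    plt.show()
    return axes  # returned only for inspection (hypothetical addition)


# Hypothetical stand-in for sample_training_images: five random 150x150 RGB images in [0, 1]
images = np.random.rand(5, 150, 150, 3)
axes = plot_images_with_subplot(images)
```

This works, but the bookkeeping of indices is exactly what `subplots` does for you.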

I believe `subplots` (note the extra `s` at the end) is what you're looking for: https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.pyplot.subplots.html

With that, you could simply do `fig, axes = plt.subplots(1, 5)`. `subplots` also accepts `figsize`, so you can keep `figsize=(20,20)` in that call.
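For completeness, a sketch of the question's function with only that one change (`subplot` to `subplots`). Assumptions: the `Agg` backend line, the random arrays standing in for `sample_training_images`, and the `return fig, axes` line are additions for a headless, inspectable run and are not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a headless run; remove to display the figure
import matplotlib.pyplot as plt
import numpy as np


def plotImages(image_arr):
    # subplots (with an "s") creates the figure and a 1x5 grid of Axes in one call
    fig, axes = plt.subplots(1, 5, figsize=(20, 20))
    axes = axes.flatten()  # already 1-D for a single row; keeps the loop working for other grid shapes
    for img, ax in zip(image_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()
    return fig, axes  # hypothetical addition so the result can be checked


# Hypothetical stand-in for sample_training_images[:5]
images = np.random.rand(5, 150, 150, 3)
fig, axes = plotImages(images)
```

In the original script the call stays `plotImages(sample_training_images[:5])`.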

answered Jul 10, 2020 at 6:22
