5
\$\begingroup\$

A Python script to generate an image with a given number of digits from MNIST data on a single row.

Via arguments, the following can be specified (all optional, with defaults if necessary):

  • width of resulting image(s)
  • minimum margin between digits
  • maximum margin between digits
  • number of MNIST digits per image
  • which characters to extract from MNIST
  • number of images to generate
  • which directory to save the images to

Possible application is to generate images to train an OCR system for digits, or some form of CAPTCHA.

#!/usr/bin/env python
# coding: utf-8
import os
import pickle
import argparse
import numpy as np
def load_data_and_dict():
 """Checks if numpy format images and idx dict exist
 if so:
 load
 if not:
 prepare and load
 :return: tuple of:
 <numpy array> mnist images
 <numpy array> dictionary
 """
 # check if data and dicts exist else download and generate
 mnist_dir = './data'
 mnist_images_fn = "./mnist_images"
 mnist_images_fn_loc = os.path.join(mnist_dir, mnist_images_fn + ".npy")
 mnist_idx_dict_fn = "./mnist_idx_dict"
 mnist_idx_dict_fn_loc = os.path.join(mnist_dir, mnist_idx_dict_fn + ".pickle")
 if os.path.isfile(mnist_images_fn_loc) and os.path.isfile(mnist_idx_dict_fn_loc):
 print("Found image data, loading...", end="")
 images = np.load(mnist_images_fn_loc)
 with open(mnist_idx_dict_fn_loc, 'rb') as handle:
 idx_dict = pickle.load(handle)
 print("DONE")
 else:
 print("Not all image data found, preparing...")
 import gzip
 import shutil
 import urllib.request
 mnist_urls = ["http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
 "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"]
 # download if not exists
 os.makedirs(mnist_dir, exist_ok=True)
 # Download the file if it does not exist
 for mnist_url in mnist_urls:
 download_filename = os.path.join(mnist_dir, os.path.basename(mnist_url))
 if not os.path.isfile(download_filename):
 print(f"Downloading: {mnist_url}")
 urllib.request.urlretrieve(mnist_url, download_filename)
 # extract zip files if necessary
 zip_files = [os.path.join(mnist_dir, x) for x in os.listdir(mnist_dir) if x.endswith('.gz')]
 targets = [x[:-3] for x in zip_files]
 for idx, zip_file in enumerate(zip_files):
 if not os.path.isfile(targets[idx]):
 print(f"Unzipping {zip_file} to {targets[idx]}")
 with gzip.open(zip_file, 'rb') as f_in:
 with open(targets[idx], 'wb') as f_out:
 shutil.copyfileobj(f_in, f_out)
 from mnist import MNIST
 mndata = MNIST(mnist_dir)
 images, labels = mndata.load_training()
 images = np.array(images)
 labels = np.array(labels)
 idx_dict = {}
 for i in range(10):
 idx_dict[i] = np.where(labels == i)[0]
 # save data for future use
 print(f"Saving images to {mnist_images_fn_loc}...", end="")
 np.save(mnist_images_fn_loc[:-4], images)
 print("DONE")
 print(f"Saving idx dict to {mnist_idx_dict_fn_loc}...", end="")
 with open(mnist_idx_dict_fn_loc, 'wb') as handle:
 pickle.dump(idx_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
 print("DONE")
 return images, idx_dict
def create_digit_sequence(n_arr, width, margin_min, margin_max, images, id_dict):
 if margin_max < margin_min:
 return "Maximum margin must be larger or equal to minimum margin"
 image_size = 28
 res = np.zeros((image_size, width, 3))
 extra_margin = width - (len(n_arr) - 1) * margin_min - len(n_arr) * image_size
 if extra_margin < 0:
 return "Current given minimum margin would result in exceeded width"
 start_idx = 0
 for x in n_arr:
 img = images[np.random.choice(id_dict[int(x)])].reshape((28, 28, 1))
 res[:, start_idx:start_idx+28, :] = img
 start_idx += 28
 additional_margin = np.random.randint(0, margin_max - margin_min)
 additional_margin = np.min((extra_margin, additional_margin))
 extra_margin -= additional_margin
 start_idx += additional_margin
 return res
if __name__ == "__main__":
 parser = argparse.ArgumentParser(description='Generate images from MNIST images for OCR training purposes.')
 parser.add_argument('-w', '--width', default='200',
 help='Width of the resulting image')
 parser.add_argument('-i', '--minmargin', default='0',
 help='Minimum margin between MNIST characters')
 parser.add_argument('-a', '--maxmargin', default='100',
 help='Maximum margin between MNIST characters')
 parser.add_argument('-l', '--strlen', default='5',
 help='number of characters per string')
 parser.add_argument('-s', '--numberstring',
 help='string of numbers ')
 parser.add_argument('-n', '--genn', default='10',
 help='number of images to generate')
 parser.add_argument('-o', '--outputdir', default='./images',
 help='output directory for generated images')
 args = parser.parse_args()
 # parse args to int if necessary
 mnist_string = args.numberstring
 min_margin = int(args.minmargin)
 max_margin = int(args.maxmargin)
 width = int(args.width)
 n = int(args.genn)
 str_len = int(args.strlen)
 char_string = args.numberstring
 # load images and idx data
 images, id_dict = load_data_and_dict()
 # create output dir if not exist
 out_dir = args.outputdir
 os.makedirs(out_dir, exist_ok=True)
 # main program loop
 for i in range(n):
 # generate a new string per loop if one wasn't provided
 if char_string is None:
 gen_arr = np.random.randint(0, 9, str_len)
 else:
 gen_arr = [int(x) for x in char_string]
 # run the generator
 mnist_ocr_image = create_digit_sequence(gen_arr, width, min_margin, max_margin, images, id_dict)
 np.save(os.path.join(out_dir, "mnist_ocr_image_{:0>6}".format(i)), mnist_ocr_image)
 print(f"Saved {n} images in {out_dir}")
AlexV
7,3532 gold badges24 silver badges47 bronze badges
asked Dec 17, 2019 at 13:01
\$\endgroup\$

1 Answer 1

2
\$\begingroup\$

As a general feedback, the script looks quite good. I will nevertheless share a few of my thoughts with you.

shebang

Since you are using Python 3, the initial shebang should be #!/usr/bin/env python3. Otherwise it will depend on the system which interpreter is used to execute the script once the file is marked as executable.

Documentation

Only load_data_and_dict() is documented using a docstring. The rest of the code is not. It might be a good idea to document the scripts behavior (e.g. directories/files expected/generated by the code) on a module level. To quote PEP 257:

The docstring of a script (a stand-alone program) should be usable as its "usage" message, printed when the script is invoked with incorrect or missing arguments (or perhaps with a "-h" option, for "help"). Such a docstring should document the script's function and command line syntax, environment variables, and files. Usage messages can be fairly elaborate (several screens full) and should be sufficient for a new user to use the command properly, as well as a complete quick reference to all options and arguments for the sophisticated user.

Since you are working in the scientific Python stack (numpy, ...), it might also be worth to have a look at numpydoc, the style used for the numpy/scipy/... documentation. It's quite expressive and ready to be used to automatically generate documentation using tools like Sphinx.

Error reporting

Returning strings to report errors is not a particularly robust idea. Instead consider raising and catching exceptions to signal something went wrong. In your case, a ValueError seems like an appropriate choice. E.g.

raise ValueError("Maximum margin must be larger or equal to minimum margin")

Also, extra_margin should be checked before creating the res image, since it's a useless memory allocation in case the extra_margin check won't pass.

Magic values

Sometimes 28 is used as "magic value", instead of referring to what is defined as image_size. Using image_size everywhere would be clearer and more consistent.

Randomness

Maybe you should think about including a possibility to manually select the seed value for the RNG. This will allow you to create reproducible, pseudo-random datasets.

Command-line interface

argparse supports the type keyword argument, which would allow you to rewrite, e.g.

parser.add_argument('-i', '--minmargin', default='0',
 help='Minimum margin between MNIST characters')
parser.add_argument('-a', '--maxmargin', default='100',
 help='Maximum margin between MNIST characters')
...
min_margin = int(args.minmargin)
max_margin = int(args.maxmargin)

as

parser.add_argument('-i', '--minmargin', default='0',
 help='Minimum margin between MNIST characters', type=int)
parser.add_argument('-a', '--maxmargin', default='100',
 help='Maximum margin between MNIST characters', type=int)
...
min_margin = args.minmargin # maybe these even become unnecessary 
max_margin = args.maxmargin
answered Jan 14, 2020 at 10:24
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.