
Training a neural network on MNIST with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.

import tensorflow as tf
import tensorflow_datasets as tfds

Step 1: Create your input pipeline

Start by building an efficient input pipeline, following the advice in the TFDS performance tips guide and the tf.data: Build TensorFlow input pipelines guide.

Load a dataset

Load the MNIST dataset with the following arguments:

  • shuffle_files=True: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
  • as_supervised=True: Returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.
(ds_train, ds_test), ds_info = tfds.load(
 'mnist',
 split=['train', 'test'],
 shuffle_files=True,
 as_supervised=True,
 with_info=True,
)
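Because with_info=True is passed, tfds.load also returns the dataset metadata as a tfds.core.DatasetInfo object. As a quick sanity check (an optional addition, not part of the original notebook), you can inspect it to confirm the feature spec and split sizes:

# Optional: inspect the metadata returned because of with_info=True.
print(ds_info.features)
print(ds_info.splits['train'].num_examples)  # 60000
print(ds_info.splits['test'].num_examples)  # 10000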

Build a training pipeline

Apply the following transformations:

  • tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32, so normalize the images.
  • tf.data.Dataset.cache: As the dataset fits in memory, cache it before shuffling for better performance. Note that random transformations should be applied after caching.
  • tf.data.Dataset.shuffle: For true randomness, set the shuffle buffer to the full dataset size.
  • tf.data.Dataset.batch: Batch elements of the dataset after shuffling to get unique batches at each epoch.
  • tf.data.Dataset.prefetch: End the pipeline by prefetching for performance.

def normalize_img(image, label):
 """Normalizes images: `uint8` -> `float32`."""
 return tf.cast(image, tf.float32) / 255., label
ds_train = ds_train.map(
 normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
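To check that the training pipeline produces what the model expects, a minimal sketch (not in the original notebook) is to pull a single batch and print its shapes and dtypes:

# Optional: take one batch to verify shapes and dtypes.
for images, labels in ds_train.take(1):
 print(images.shape, images.dtype)  # (128, 28, 28, 1) float32
 print(labels.shape, labels.dtype)  # (128,) int64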

Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

  • You don't need to call tf.data.Dataset.shuffle.
  • Caching is done after batching because batches can be the same between epochs.
ds_test = ds_test.map(
 normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
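As a quick check (an optional addition, not from the original notebook), you can confirm how many batches the evaluation pipeline yields per epoch; 10,000 test images at a batch size of 128 gives ceil(10000 / 128) = 79 batches:

# Optional: number of batches the test pipeline yields per epoch.
print(ds_test.cardinality().numpy())  # 79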

Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

model = tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(128, activation='relu'),
 tf.keras.layers.Dense(10)
])
model.compile(
 optimizer=tf.keras.optimizers.Adam(0.001),
 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(
 ds_train,
 epochs=6,
 validation_data=ds_test,
)
Epoch 1/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 4ms/step - loss: 0.6073 - sparse_categorical_accuracy: 0.8348 - val_loss: 0.1852 - val_sparse_categorical_accuracy: 0.9468
Epoch 2/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.1737 - sparse_categorical_accuracy: 0.9514 - val_loss: 0.1307 - val_sparse_categorical_accuracy: 0.9619
Epoch 3/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.1192 - sparse_categorical_accuracy: 0.9661 - val_loss: 0.1063 - val_sparse_categorical_accuracy: 0.9697
Epoch 4/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0948 - sparse_categorical_accuracy: 0.9727 - val_loss: 0.0951 - val_sparse_categorical_accuracy: 0.9723
Epoch 5/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9786 - val_loss: 0.0892 - val_sparse_categorical_accuracy: 0.9738
Epoch 6/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0599 - sparse_categorical_accuracy: 0.9830 - val_loss: 0.0794 - val_sparse_categorical_accuracy: 0.9772
<keras.src.callbacks.history.History at 0x7fbc39d63880>
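After training, the same evaluation pipeline can be reused outside model.fit. The snippet below is a small follow-up sketch (not part of the original notebook) that reports the final test loss and accuracy with model.evaluate:

# Optional follow-up: evaluate the trained model on the test pipeline.
test_loss, test_acc = model.evaluate(ds_test)
print(f'Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}')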
