Run a calculation on a Cloud TPU VM using JAX

This document provides a brief introduction to working with JAX and Cloud TPU.

Before you begin

Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.

Create a Cloud TPU VM using gcloud

  1. Define some environment variables to make commands easier to use.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Environment variable descriptions

    PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME: The name of the TPU.
    ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION: The Cloud TPU software version.

  2. Create your TPU VM by running the following command from Cloud Shell or from a terminal where the Google Cloud CLI is installed.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Connect to your Cloud TPU VM

Connect to your TPU VM over SSH by using the following command:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

Install JAX on your Cloud TPU VM

Install JAX by running the following command on your TPU VM:

(vm)$ pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

System check

Verify that JAX can access the TPU and can run basic operations:

  1. Start the Python 3 interpreter:

    (vm)$ python3
    >>> import jax
  2. Display the number of TPU cores available:

    >>> jax.device_count()

The number of TPU cores is displayed. This number depends on the TPU version you are using; for a v5litepod-8, it is 8. For more information, see TPU versions.
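If you want more detail than a count, jax.devices() lists each accelerator that JAX can see. The session below is a minimal sketch; the abbreviated device representations and the count of 8 assume the v5litepod-8 configured earlier:

>>> jax.devices()              # one entry per TPU core visible to JAX
[TpuDevice(id=0, ...), TpuDevice(id=1, ...), ...]
>>> jax.local_device_count()   # cores attached to this host
8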

Perform a calculation

>>> jax.numpy.add(1, 1)

The result of the addition is displayed:

Array(2, dtype=int32, weak_type=True)
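Scalar addition only confirms that the plumbing works. To exercise the TPU's matrix unit, you can run a jit-compiled matrix multiplication. The following is a minimal sketch; the shapes and the use of jax.jit are illustrative, not part of the original walkthrough:

>>> import jax.numpy as jnp
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (1024, 1024))   # random matrix, materialized on the TPU
>>> y = jax.jit(jnp.matmul)(x, x)              # compiled matmul, executed on the TPU
>>> y.shape
(1024, 1024)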

Exit the Python interpreter

>>> exit()

Run JAX code on a TPU VM

You can now run any JAX code you want. The Flax examples are a good place to start running standard machine learning models in JAX. For example, to train a basic MNIST convolutional network:

  1. Install Flax examples dependencies:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Install Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Run the Flax MNIST training script:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

The script downloads the dataset and starts training. The script output should look like this:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, clean them up when you are done with your TPU VM.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Verify that the resources have been deleted by running the following command and confirming that your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Performance notes

A few details are particularly relevant to using TPUs with JAX.

Padding

One of the most common causes of slow performance on TPUs is inadvertent padding:

  • Arrays in Cloud TPU memory are tiled, which entails padding one dimension to a multiple of 8 and a different dimension to a multiple of 128 (see the sketch after this list).
  • The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
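As a rough illustration, the helper below checks whether a shape lines up with that tiling. The function name and the example shapes are our own, not from the original text; treat it as a heuristic sketch rather than an exact model of the compiler's layout rules:

import jax.numpy as jnp

def is_tile_aligned(shape):
    # Heuristic: last dimension a multiple of 128, second-to-last a multiple of 8.
    return shape[-1] % 128 == 0 and shape[-2] % 8 == 0

a = jnp.ones((512, 512))   # 512 is a multiple of both 8 and 128: no padding needed
b = jnp.ones((513, 129))   # both dimensions force the compiler to pad each tile

print(is_tile_aligned(a.shape))  # True
print(is_tile_aligned(b.shape))  # False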

bfloat16 dtype

By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision argument on relevant jax.numpy function calls (matmul, dot, einsum, and so on), as shown in the sketch after this list. In particular:

  • precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
  • precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
  • precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision

JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16. For example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
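The following minimal sketch shows both controls together; the matrix shapes and values are illustrative:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (128, 128))
b = jax.random.normal(key, (128, 128))

# Default: bfloat16 inputs on the MXU with float32 accumulation (fastest).
fast = jnp.matmul(a, b, precision=jax.lax.Precision.DEFAULT)

# Extra MXU passes recover full float32 precision (slowest, most accurate).
exact = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Explicitly cast an array to the bfloat16 dtype.
a_bf16 = jnp.array(a, dtype=jnp.bfloat16)
print(a_bf16.dtype)  # bfloat16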

What's next

For more information about Cloud TPU, see the Cloud TPU documentation.
