MaxDiffusion inference on v6e TPUs

This tutorial shows how to serve MaxDiffusion models on TPU v6e by generating images with the Stable Diffusion XL model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.

  2. Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE
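
The commands in this tutorial reuse the same placeholders (PROJECT_ID, ZONE, TPU_NAME, QUEUED_RESOURCE_ID, SERVICE_ACCOUNT). One optional way to avoid retyping them is to keep them in shell variables and check that none is empty before running a command. The sketch below is not part of the gcloud CLI; the helper name require_vars is hypothetical.

```shell
# Hypothetical helper: fail early if any named variable is unset or empty.
require_vars() {
  for _rv_name in "$@"; do
    # Indirect lookup that works in POSIX sh as well as bash.
    eval "_rv_value=\${$_rv_name:-}"
    if [ -z "$_rv_value" ]; then
      echo "Missing required variable: $_rv_name" >&2
      return 1
    fi
  done
}

# Example usage (replace the placeholder values with your own):
#   PROJECT_ID=PROJECT_ID
#   ZONE=ZONE
#   require_vars PROJECT_ID ZONE && gcloud config set project "$PROJECT_ID"
```

Checking the variables first keeps a later gcloud command from running with an empty --project or --zone value.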

Secure capacity

When you are ready to secure TPU capacity, see Cloud TPU Quotas for details. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources. This tutorial uses queued resources.

Prerequisites

  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough quota for:
    • TPU VMs
    • IP addresses
    • Hyperdisk Balanced
  • Verify that you have the required user project permissions.

Provision a TPU v6e

gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
    --node-id TPU_NAME \
    --project PROJECT_ID \
    --zone ZONE \
    --accelerator-type v6e-4 \
    --runtime-version v2-alpha-tpuv6e \
    --service-account SERVICE_ACCOUNT

Use the list or describe commands to query the status of your queued resource.

gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID \
    --project=PROJECT_ID --zone=ZONE

For a complete list of queued resource request statuses, see the Queued Resources documentation.
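
Instead of running describe by hand, you might poll until the request becomes usable. The following is a minimal sketch, not an official procedure. It assumes that PROJECT_ID and ZONE are set as shell variables, that the state is exposed at state.state in the describe output, and that ACTIVE, FAILED, and SUSPENDED are the states worth stopping on. The function names are hypothetical.

```shell
# Hypothetical helper: print the current state of a queued resource.
# Assumes PROJECT_ID and ZONE are set in the shell.
queued_resource_state() {
  gcloud alpha compute tpus queued-resources describe "$1" \
    --project="$PROJECT_ID" --zone="$ZONE" \
    --format="value(state.state)"
}

# Hypothetical helper: poll until the queued resource is ACTIVE.
# $1 = queued resource ID, $2 = seconds between polls (default 30).
wait_until_active() {
  while true; do
    # Stop if the describe call itself fails (for example, bad credentials).
    _state=$(queued_resource_state "$1") || return 1
    echo "State: $_state"
    case "$_state" in
      ACTIVE) return 0 ;;
      FAILED|SUSPENDED) return 1 ;;
    esac
    sleep "${2:-30}"
  done
}

# Example usage:
#   wait_until_active QUEUED_RESOURCE_ID 60
```

The loop has no timeout, so add one if you run it unattended.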

Connect to the TPU using SSH

gcloud compute tpus tpu-vm ssh TPU_NAME

Create a Conda environment

  1. Create a directory for Miniconda:

    mkdir -p ~/miniconda3
  2. Download the Miniconda installer script:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Install Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Remove the Miniconda installer script:

    rm -rf ~/miniconda3/miniconda.sh
  5. Add Miniconda to your PATH variable:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Reload ~/.bashrc to apply the changes to the PATH variable:

    source ~/.bashrc
  7. Create a new Conda environment:

    conda create -n tpu python=3.10
  8. Activate the Conda environment:

    source activate tpu
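
The export command in step 5 changes PATH only for the current shell session; it does not write anything to ~/.bashrc. If you want Miniconda on your PATH in later SSH sessions, one option is to append the export line to ~/.bashrc yourself. The sketch below does that without adding duplicate lines; the helper name add_line_once is hypothetical.

```shell
# Hypothetical helper: append a line to a file only if the exact line
# is not already present. $1 = file, $2 = line to add.
add_line_once() {
  # -x matches whole lines, -F treats the pattern as a fixed string.
  grep -qxF -- "$2" "$1" 2>/dev/null || printf '%s\n' "$2" >> "$1"
}

# Example usage (single quotes keep $HOME and $PATH unexpanded in the file):
#   add_line_once ~/.bashrc 'export PATH="$HOME/miniconda3/bin:$PATH"'
```

Because the helper is idempotent, running the setup steps again does not grow ~/.bashrc.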

Set up MaxDiffusion

  1. Clone the MaxDiffusion GitHub repository and navigate to the MaxDiffusion directory:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Switch to the mlperf4.1 branch:

    git checkout mlperf4.1
  3. Install MaxDiffusion:

    pip install -e .
  4. Install dependencies:

    pip install -r requirements.txt
  5. Install JAX:

    pip install jax[tpu]==0.4.34 jaxlib==0.4.34 ml-dtypes==0.2.0 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  6. Install additional dependencies:

    pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers
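
Because step 6 installs several unpinned packages after JAX, it can be worth confirming that the pinned JAX versions from step 5 are still the ones installed. The following sketch checks one pin at a time against the output of pip list --format=freeze. The helper name check_pin is hypothetical, and the check assumes pip reports the package under the same name you pass in.

```shell
# Hypothetical helper: succeed only if package $1 is installed at
# exactly version $2 in the active environment.
check_pin() {
  # pip prints one "name==version" entry per line in freeze format.
  # -i ignores case, -x matches whole lines, -F disables regex matching.
  pip list --format=freeze 2>/dev/null | grep -qixF -- "$1==$2"
}

# Example usage:
#   check_pin jax 0.4.34 || echo "jax is not at 0.4.34" >&2
#   check_pin jaxlib 0.4.34 || echo "jaxlib is not at 0.4.34" >&2
```

If a pin no longer matches, rerunning the JAX install command from step 5 is one way to restore it.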

Generate images

  1. Set environment variables to configure the TPU runtime:

    export LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Generate images using the prompt and configurations defined in src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

    When the images have been generated, be sure to clean up the TPU resources.
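
LIBTPU_INIT_ARGS only affects the python process if it is part of that process's environment. A plain assignment creates a shell variable that child processes do not inherit, while an exported variable is passed on. The demonstration below uses a hypothetical variable, DEMO_FLAG, and a hypothetical helper, child_view, to show the difference; it does not touch the TPU runtime.

```shell
# DEMO_FLAG is a hypothetical variable used only for this demonstration.
child_view() {
  # Start a child process and print the value it sees, if any.
  sh -c 'echo "${DEMO_FLAG:-<unset>}"'
}

unset DEMO_FLAG     # start from a clean state
DEMO_FLAG="value"   # plain assignment: visible in this shell only
child_view          # the child process does not inherit it

export DEMO_FLAG    # mark the variable for export
child_view          # the child process now inherits it
```

The first call prints <unset> and the second prints value. Prefixing the assignment on the same command line as the python command has the same effect for that one command.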

Clean up

Delete the TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async
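
Because --async returns before the deletion finishes, you may want to confirm later that the queued resource is gone. The sketch below uses the list command for that. It assumes that PROJECT_ID and ZONE are set as shell variables and that each listed name ends with the queued resource ID; the function name is hypothetical.

```shell
# Hypothetical helper: succeed (0) if queued resource $1 no longer appears
# in the list output, return 1 if it is still listed, and return 2 if the
# list command itself fails.
queued_resource_gone() {
  _listing=$(gcloud alpha compute tpus queued-resources list \
    --project="$PROJECT_ID" --zone="$ZONE" \
    --format="value(name)") || return 2
  # Match the ID as a whole name or as the last path segment.
  if printf '%s\n' "$_listing" | grep -Eq "(^|/)$1\$"; then
    return 1
  fi
  return 0
}

# Example usage:
#   queued_resource_gone QUEUED_RESOURCE_ID && echo "Deleted"
```

Treating a failed list call separately avoids reporting a resource as deleted when the command could not run.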

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2025-10-13 UTC.