MaxDiffusion inference on v6e TPUs
This tutorial shows how to serve MaxDiffusion models on TPU v6e. In this tutorial, you generate images using the Stable Diffusion XL model.
Before you begin
Prepare to provision a TPU v6e with 4 chips:
Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.
Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.
gcloud auth login
gcloud config set project PROJECT_ID
gcloud config set compute/zone ZONE
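The commands in this tutorial use placeholders such as PROJECT_ID and ZONE. If you keep these values in shell variables instead, a small guard can catch an unset value before it reaches gcloud. The following is a minimal sketch; require_set is a hypothetical helper written for illustration, not part of the Google Cloud CLI.

```shell
# require_set NAME...: hypothetical helper that fails if any of the named
# shell variables is unset or empty. Pass variable names, not values.
require_set() {
  local name value missing=0
  for name in "$@"; do
    # Indirect lookup of the variable called $name (portable across sh and bash).
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "error: $name is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

With PROJECT_ID and ZONE assigned in your shell, you could then run: require_set PROJECT_ID ZONE && gcloud config set project "$PROJECT_ID"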
Secure capacity
When you are ready to secure TPU capacity, see Cloud TPU Quotas for information about quotas. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.
Provision the Cloud TPU environment
You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.
Prerequisites
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- Verify that your project has enough TPU quota for:
  - TPU VM quota
  - IP address quota
  - Hyperdisk Balanced quota
- User project permissions
  - If you are using GKE with XPK, see Cloud Console Permissions on the user or service account for the permissions needed to run XPK.
Provision a TPU v6e
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
    --node-id TPU_NAME \
    --project PROJECT_ID \
    --zone ZONE \
    --accelerator-type v6e-4 \
    --runtime-version v2-alpha-tpuv6e \
    --service-account SERVICE_ACCOUNT
Use the list or describe commands to query the status of your queued resource.
gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID \
    --project=PROJECT_ID --zone=ZONE
For a complete list of queued resource request statuses, see the Queued Resources documentation.
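Rather than running describe by hand, you can poll until the queued resource is ready. The following is a minimal sketch, not an official workflow. wait_for_state and queued_resource_state are hypothetical helpers. The state names ACTIVE and FAILED and the state.state field path are assumptions here, so check them against the Queued Resources documentation.

```shell
# wait_for_state TARGET MAX_ATTEMPTS SLEEP_SECONDS COMMAND [ARGS...]
# Hypothetical helper: runs COMMAND repeatedly and succeeds once its
# output equals TARGET. Gives up after MAX_ATTEMPTS or on FAILED.
wait_for_state() {
  local target="$1" max_attempts="$2" delay="$3"
  shift 3
  local attempt=1 state
  while [ "$attempt" -le "$max_attempts" ]; do
    # A failing command is treated as "state unknown" and retried.
    state="$("$@" 2>/dev/null)" || state=""
    echo "attempt ${attempt}: state=${state:-unknown}" >&2
    if [ "$state" = "$target" ]; then
      return 0
    fi
    # Assumption: FAILED is a terminal state, so stop polling early.
    if [ "$state" = "FAILED" ]; then
      return 1
    fi
    sleep "$delay"
    attempt=$((attempt + 1))
  done
  return 1
}

# queued_resource_state QUEUED_RESOURCE_ID PROJECT_ID ZONE
# Assumption: the describe output exposes the state at state.state.
queued_resource_state() {
  gcloud alpha compute tpus queued-resources describe "$1" \
    --project="$2" --zone="$3" --format="value(state.state)"
}
```

For example, to check every 30 seconds for up to 60 attempts: wait_for_state ACTIVE 60 30 queued_resource_state QUEUED_RESOURCE_ID PROJECT_ID ZONE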
Connect to the TPU using SSH
gcloud compute tpus tpu-vm ssh TPU_NAME
Create a Conda environment
Create a directory for Miniconda:
mkdir -p ~/miniconda3
Download the Miniconda installer script:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
Install Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
Remove the Miniconda installer script:
rm -rf ~/miniconda3/miniconda.sh
Add Miniconda to your PATH variable:
export PATH="$HOME/miniconda3/bin:$PATH"
Reload ~/.bashrc to apply the changes to the PATH variable:
source ~/.bashrc
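An export typed at the prompt changes PATH only for the current shell session. If you want Miniconda on PATH in later SSH sessions as well, one option is to append the export line to ~/.bashrc. The following is a minimal sketch; append_once is a hypothetical helper that avoids adding the same line twice.

```shell
# append_once FILE LINE: hypothetical helper that appends LINE to FILE
# unless FILE already contains an identical line.
append_once() {
  local file="$1" line="$2"
  touch "$file"
  # -x matches whole lines only, -F treats LINE as a literal string.
  if ! grep -qxF -- "$line" "$file"; then
    printf '%s\n' "$line" >> "$file"
  fi
}
```

For example: append_once "$HOME/.bashrc" 'export PATH="$HOME/miniconda3/bin:$PATH"'. The single quotes keep $HOME and $PATH unexpanded, so they are evaluated each time ~/.bashrc is sourced.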
Create a new Conda environment:
conda create -n tpu python=3.10
Activate the Conda environment:
source activate tpu
Set up MaxDiffusion
Clone the MaxDiffusion GitHub repository and navigate to the MaxDiffusion directory:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Switch to the mlperf-4.1 branch:
git checkout mlperf4.1
Install MaxDiffusion:
pip install -e .
Install dependencies:
pip install -r requirements.txt
Install JAX:
pip install jax[tpu]==0.4.34 jaxlib==0.4.34 ml-dtypes==0.2.0 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Install additional dependencies:
pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers
Generate images
Set environment variables to configure the TPU runtime:
export LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
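A child process such as the Python interpreter only inherits shell variables that have been exported. The following is a minimal sketch of one way to confirm that the flags will reach the process you launch next; child_sees_flags is a hypothetical helper, and the flag values are the ones from this tutorial.

```shell
# Assign and export the libtpu flags so child processes inherit them.
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
export LIBTPU_INIT_ARGS

# child_sees_flags: hypothetical helper that starts a child shell and
# succeeds only if LIBTPU_INIT_ARGS is set and non-empty there.
child_sees_flags() {
  sh -c '[ -n "${LIBTPU_INIT_ARGS:-}" ]'
}

if child_sees_flags; then
  echo "LIBTPU_INIT_ARGS is visible to child processes"
else
  echo "LIBTPU_INIT_ARGS is not visible to child processes" >&2
fi
```

If the check fails, the variable was assigned without export, and the generation command would run without these flags.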
Generate images using the prompt and configurations defined in src/maxdiffusion/configs/base_xl.yml:
python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
When the images have been generated, be sure to clean up the TPU resources.
Clean up
Delete the TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async