Tutorial: Exporting StableHLO from PyTorch
PyTorch is a popular library for building deep learning models. In this tutorial, you will learn to export a PyTorch model to StableHLO, and then directly to TensorFlow SavedModel.
Tutorial Setup
Install required dependencies
We use torch and torchvision to get a ResNet18 model, and torch_xla to export it to StableHLO.
We also need to install tensorflow to work with SavedModel, and recommend using tensorflow-cpu or tf-nightly for this tutorial.
pip install torch_xla==2.5.0 torch==2.5.0 torchvision==0.20.0 tensorflow-cpu
Export PyTorch model to StableHLO
The general set of steps for exporting a PyTorch model to StableHLO is:
- Use PyTorch's torch.export API to generate an exported FX graph (i.e., an ExportedProgram)
- Use PyTorch/XLA's torch_xla.stablehlo API to convert the ExportedProgram to StableHLO
Export model to FX graph using torch.export
This step uses vanilla PyTorch APIs to export a resnet18 model from torchvision.
Sample inputs are required for graph tracing; we use a tensor<4x3x224x224xf32> in this case.
import torch
import torchvision
from torch.export import export
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
sample_input = (torch.randn(4, 3, 224, 224), )
exported = export(resnet18, sample_input)
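If you want to sanity-check the traced graph before lowering it, the ExportedProgram can be printed directly; its representation includes the graph module and the input/output signature.
# Optional: inspect the exported program before lowering it
print(exported)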
Export FX graph to StableHLO using torch_xla.stablehlo
Once we have an exported FX graph, we can convert it to StableHLO using exported_program_to_stablehlo in the torch_xla.stablehlo module.
We can then look at the exported StableHLO program with get_stablehlo_text.
from torch_xla.stablehlo import exported_program_to_stablehlo
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward')[0:4000], "\n...")
WARNING:root:Defaulting to PJRT_DEVICE=CPU
module @IrToHlo.484 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {
%cst = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>
%cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>
%output, %batch_mean, %batch_var = "stablehlo.ba
...
Tip:
Dynamic batch dimensions can be specified as part of the initial torch.export step, as sketched below.
torch_xla's support for exporting dynamic models is limited; for those cases we recommend torch_xla2 instead. That lowering path leverages JAX to lower to StableHLO, has high opset coverage, and offers much broader support for exported programs with dynamic shapes.
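For illustration, here is a minimal sketch of declaring a dynamic batch dimension with torch.export's Dim API; whether the subsequent StableHLO lowering succeeds depends on your torch_xla version, per the caveat above.
from torch.export import Dim

# Mark dim 0 of the single sample input as dynamic.
batch = Dim("batch")
exported_dynamic = export(resnet18, sample_input, dynamic_shapes=({0: batch},))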
Save and reload StableHLO
StableHLOGraphModule has methods to save and load StableHLO artifacts.
This stores StableHLO portable bytecode artifacts, which come with StableHLO's forward and backward compatibility guarantees.
from torch_xla.stablehlo import StableHLOGraphModule
# Save to tmp
stablehlo_program.save('/tmp/stablehlo_dir')
!ls /tmp/stablehlo_dir
!ls /tmp/stablehlo_dir/functions
constants  data  functions
forward.bytecode  forward.meta  forward.mlir
# Reload and execute - Stable serialization, forward / backward compatible.
reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')
print(reloaded(sample_input[0]))
tensor([[-2.3258, -0.9606, -0.9439,  ...,  0.3519,  0.6261,  2.3971],
        [ 1.6479, -0.0268,  1.0511,  ..., -1.2512,  2.2042,  1.8865],
        [ 0.1756, -0.3658, -0.0651,  ...,  0.0661,  2.1358,  0.5009],
        [-1.6709, -0.7363, -2.0963,  ..., -1.3716,  0.3321, -0.9199]],
       device='xla:0')
Note: You can also use convenience wrappers like save_torch_model_as_stablehlo to export and save. Learn more in the PyTorch/XLA documentation on exporting to StableHLO.
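For reference, here is a minimal sketch of that wrapper, which folds the export and save steps into a single call; see the linked documentation for the exact options it accepts.
from torch_xla.stablehlo import save_torch_model_as_stablehlo

# One-step equivalent of torch.export + exported_program_to_stablehlo + save.
save_torch_model_as_stablehlo(resnet18, sample_input, '/tmp/stablehlo_dir')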
Export to TensorFlow SavedModel
It is common to want to export a StableHLO model to TensorFlow SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via TensorFlow Serving.
PyTorch/XLA's torch_xla.tf_saved_model_integration module makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.
Export to SavedModel with torch_xla.tf_saved_model_integration
We use the save_torch_module_as_tf_saved_model function for this conversion, which uses the torch.export and torch_xla.stablehlo.exported_program_to_stablehlo functions under the hood.
The input to the API is a PyTorch model, and we use the same resnet18 from the previous examples.
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
save_torch_module_as_tf_saved_model(
    resnet18,         # original PyTorch torch.nn.Module
    sample_input,     # sample inputs used to trace
    '/tmp/resnet_tf'  # directory for tf.saved_model
)
!ls /tmp/resnet_tf/
assets fingerprint.pb saved_model.pb variables
Reload and call the SavedModel
Now we can load that SavedModel and run it on the sample_input from the previous example.
Note: The restored model does not require PyTorch or PyTorch/XLA to run, just XLA.
import tensorflow as tf
loaded_m = tf.saved_model.load('/tmp/resnet_tf')
print(loaded_m.f(tf.constant(sample_input[0].numpy())))
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1730760467.760638 8492 service.cc:148] XLA service 0x7ede002016e0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1730760467.760777 8492 service.cc:156]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1730760468.613723 8492 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[-2.3257551 , -0.96061766, -0.9439326 , ...,  0.35189423,
         0.62605226,  2.3971176 ],
       [ 1.6479174 , -0.02676968,  1.0511047 , ..., -1.2511721 ,
         2.2041895 ,  1.8865337 ],
       [ 0.17559683, -0.365776  , -0.06507193, ...,  0.06606296,
         2.135755  ,  0.500913  ],
       [-1.6709077 , -0.7362997 , -2.0962732 , ..., -1.3716122 ,
         0.33205754, -0.91991633]], dtype=float32)>]
Troubleshooting
Version mismatch
Ensure that you have the same version of PyTorch/XLA and PyTorch. Version mismatch can result in import errors, as well as some runtime issues.
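A quick way to check is to print both versions; the major.minor versions should match (e.g., both 2.5.x).
import torch
import torch_xla

print(torch.__version__)
print(torch_xla.__version__)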
Export bugs
If your program fails to export due to a bug in the PyTorch/XLA bridge, open an issue on GitHub with a reproducible example:
- Issues in torch.export: report these in the upstream pytorch/pytorch repository
- Issues in torch_xla.stablehlo: open a ticket in the pytorch/xla repository