Note: Alpa is no longer actively maintained and is available as a research artifact. The core algorithm in Alpa has been merged into XLA, which is still being maintained: https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding
Alpa is a system for training and serving large-scale neural networks.
Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks require complicated distributed system techniques. Alpa aims to automate large-scale distributed training and serving with just a few lines of code.
The key features of Alpa include:
💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism (see the sketch after this feature list).
🚀 Excellent Performance. Alpa achieves linear scaling on training models with billions of parameters on distributed clusters.
✨ Tight Integration with Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.
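To make the parallelism options concrete, here is a minimal sketch of how a training step can be parallelized with different strategies. It is illustrative only: the `alpa.ShardParallel` and `alpa.PipeshardParallel` method objects, their arguments, and the use of `alpa.grad` are assumptions based on Alpa's documented API and may differ across versions.

```python
import alpa
import jax.numpy as jnp

# Intra-operator parallelism (data + operator parallelism):
# each operator is sharded across the devices in the cluster.
shard_method = alpa.ShardParallel()

# Combined intra- and inter-operator parallelism:
# the model is additionally split into pipeline stages,
# and each micro-batch flows through the pipeline.
pipeshard_method = alpa.PipeshardParallel(num_micro_batches=16)

# Pass the chosen method to the decorator; without a method argument,
# @alpa.parallelize falls back to a default strategy.
@alpa.parallelize(method=pipeshard_method)
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    # Pipeline parallelism expects alpa.grad in place of jax.grad.
    grads = alpa.grad(loss_func)(model_state.params)
    return model_state.apply_gradient(grads)
```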
The code below shows how to use the huggingface/transformers interface and the Alpa distributed backend for large model inference. Detailed documentation is in Serving OPT-175B using Alpa.
```python
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# Load the model. Alpa automatically downloads the weights to the specified path
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

# Generate
prompt = "Paris is the capital city of"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)
```
Use Alpa's decorator `@parallelize` to scale your single-device training code to distributed clusters. Check out the documentation site and the examples folder for installation instructions, tutorials, examples, and more.
```python
import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
# (create_train_state and data_loader are defined by the user's training script).
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
```
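The "designated cluster" referenced above is typically a Ray cluster that Alpa attaches to at startup. Below is a minimal sketch of that setup; it assumes a Ray cluster has already been launched (e.g. `ray start --head` on the head node) and that `alpa.init(cluster="ray")` matches the installed Alpa version.

```python
import alpa
import ray

# Connect to a running Ray cluster (started beforehand with `ray start --head`
# on the head node and `ray start --address=<head-node-ip>:6379` on the workers).
ray.init(address="auto")

# Hand the Ray cluster's devices over to Alpa for distributed execution.
alpa.init(cluster="ray")
```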
- Connect with Alpa developers via the Alpa Slack.
- Please read the contributor guide if you are interested in contributing code.
Alpa is licensed under the Apache-2.0 license.