Early Development - API Unstable
Datarax is in early development and undergoing rapid iteration. Breaking changes are expected. Pin to specific commits if stability is required. We recommend waiting for a stable release (v1.0) before using Datarax in production.
Datarax (Data + Array/JAX) is an extensible data pipeline framework built for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build data loading, preprocessing, and augmentation pipelines that run on CPUs, GPUs, and TPUs.
- JAX-Native Design: All core components are built on JAX's functional paradigm, with the Flax NNX module system handling state management
- High Performance: JIT-compiled pipelines via XLA, with built-in profiling and roofline analysis
- DAG Execution Engine: Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes
- Scalability: Multi-device and multi-host data distribution with device mesh sharding
- Determinism: Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory)
- Extensibility: Custom data sources, operators, and augmentation strategies via composable NNX modules
- Benchmarking Suite: Comparative benchmarks against 12+ frameworks with calibrax-powered analysis and regression checks
- Ecosystem Integration: Works with Flax, Optax, Orbax, HuggingFace Datasets, and TensorFlow Datasets
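The Scalability bullet above mentions device mesh sharding. Datarax's own sharding classes (ArraySharder, DeviceMesh) are not reproduced here, and their API is not assumed. As a minimal sketch of the plain JAX mechanism that device-mesh data distribution generally builds on, the snippet below splits a batch's leading axis across every visible device using `jax.sharding`. The axis name `"data"` and the per-device batch of 4 are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all visible devices; "data" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the leading (batch) axis over "data"; the remaining axes are not split.
batch_sharding = NamedSharding(mesh, P("data"))

# The global batch must divide evenly across the devices on the "data" axis.
per_device = 4
global_batch = per_device * len(jax.devices())
images = jnp.zeros((global_batch, 28, 28, 1), dtype=jnp.float32)

sharded = jax.device_put(images, batch_sharding)
print(sharded.shape, sharded.sharding)

# Each device holds its own (per_device, 28, 28, 1) slice of the global batch.
for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a single-device CPU machine this still runs; the mesh simply has one device holding the whole batch.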
JAX has mature libraries for models (Flax), optimizers (Optax), and checkpointing (Orbax), but lacks a dedicated data pipeline framework that operates at the same level of abstraction. Existing options are either framework-agnostic loaders that return NumPy arrays (losing JIT/autodiff benefits) or wrappers around tf.data/PyTorch that introduce cross-framework overhead. Datarax aims to fill this gap. The framework is under active development with ongoing performance optimization — the architecture is functional, but throughput and API surface are still being refined.
Every component — sources, operators, batchers, samplers, sharders — is a Flax NNX module. Pipeline state is managed through NNX's variable system, which means operators can hold learnable parameters, be serialized with Orbax, and participate in JAX transformations (jit, vmap, grad) without special handling.
Because operators are NNX modules, gradients flow through the entire pipeline. This enables approaches that are not possible with standard data loaders:
- Gradient-based augmentation search — replacing RL-based methods like AutoAugment with direct optimization
- Task-optimized preprocessing — backpropagating task loss through every processing stage
- Differentiable audio synthesis — extending the same pattern to non-vision domains
See the differentiable pipeline examples for details.
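The core idea behind task-optimized preprocessing can be sketched in plain JAX, independently of datarax's operator classes: if the preprocessing step is differentiable, `jax.grad` of a downstream loss reaches the preprocessing parameters themselves. In this sketch the contrast/brightness operator, the "model" (mean pixel value), and the synthetic targets are all illustrative assumptions; a real pipeline would put a network and real data in their place.

```python
import jax
import jax.numpy as jnp


def preprocess(params, images):
    # Hypothetical differentiable operator: contrast and brightness adjustment.
    return images * params["contrast"] + params["brightness"]


def task_loss(params, images, targets):
    # Stand-in "model": mean pixel value per image, regressed onto the targets.
    preds = preprocess(params, images).mean(axis=(1, 2))
    return jnp.mean((preds - targets) ** 2)


images = jax.random.uniform(jax.random.PRNGKey(0), (8, 4, 4))
targets = 2.0 * images.mean(axis=(1, 2)) + 0.5  # synthetic task

params = {"contrast": jnp.asarray(1.0), "brightness": jnp.asarray(0.0)}
loss_and_grad = jax.jit(jax.value_and_grad(task_loss))

initial_loss, _ = loss_and_grad(params, images, targets)
for _ in range(200):
    loss, grads = loss_and_grad(params, images, targets)
    # Plain SGD step applied to the preprocessing parameters.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

final_loss, _ = loss_and_grad(params, images, targets)
print(float(initial_loss), float(final_loss), params)
```

The gradient of the task loss with respect to `contrast` and `brightness` is exactly what a standard data loader cannot provide, because its preprocessing runs outside the traced computation.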
Pipelines are directed acyclic graphs, not linear chains. The >> operator composes sequential steps, | creates parallel branches, and control-flow nodes (Branch, Merge, SplitField) handle conditional and multi-path logic. The DAG executor manages scheduling, caching, and rebatching across the graph.
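How `>>` and `|` can build a graph is easiest to see with Python operator overloading. The `Node` class below is a hypothetical illustration, not datarax's implementation: `a >> b` feeds `a`'s output into `b`, and `a | b` runs both branches on the same input and returns a tuple.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Node:
    """Hypothetical pipeline node wrapping a one-argument function."""

    fn: Callable[[Any], Any]

    def __call__(self, x: Any) -> Any:
        return self.fn(x)

    def __rshift__(self, other: Node) -> Node:
        # a >> b: sequential composition.
        return Node(lambda x: other(self(x)))

    def __or__(self, other: Node) -> Node:
        # a | b: parallel branches over the same input, outputs collected in a tuple.
        return Node(lambda x: (self(x), other(x)))


double = Node(lambda x: 2 * x)
inc = Node(lambda x: x + 1)
merge = Node(lambda pair: sum(pair))

# Parentheses matter: in Python, >> binds more tightly than |.
graph = double >> (inc | double) >> merge
print(graph(3))  # 3 -> 6 -> (7, 12) -> 19
```

In this sketch, chaining `a | b | c` nests tuples as `((a, b), c)`; a real implementation would typically flatten branches into a dedicated parallel node.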
Shuffling uses Grain's Feistel cipher permutation, which defines a full-epoch permutation and computes each shuffled index on demand in O(1) memory, without materializing the index array. Combined with explicit RNG key threading through every stochastic operator, pipelines produce identical output given the same seed, across restarts, devices, and host counts.
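The O(1)-memory property comes from treating the shuffle as a keyed bijection rather than a stored array. Below is a minimal sketch of a Feistel-network permutation over `[0, n)` with cycle-walking. It is not Grain's code: the BLAKE2b round function, the four rounds, and the function names are stand-ins chosen for illustration.

```python
import hashlib


def _round_fn(value: int, seed: int, round_idx: int, bits: int) -> int:
    """Keyed pseudo-random round function, truncated to `bits` bits (stand-in)."""
    digest = hashlib.blake2b(f"{seed}:{round_idx}:{value}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") & ((1 << bits) - 1)


def feistel_index(index: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map position `index` in [0, n) to its shuffled position using O(1) memory."""
    if not 0 <= index < n:
        raise IndexError(index)
    # Work in a domain of 2 * half_bits bits, which is large enough to contain n.
    half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = index
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            # Each round (L, R) -> (R, L ^ F(R)) is invertible, so the network is a bijection.
            left, right = right, left ^ _round_fn(right, seed, r, half_bits)
        x = (left << half_bits) | right
        # Cycle-walking: re-encrypt until the value lands back inside [0, n).
        if x < n:
            return x


order = [feistel_index(i, 10, seed=0) for i in range(10)]
print(order)  # a permutation of 0..9; same seed -> same order
```

Any epoch position can be computed directly, so restarting mid-epoch only requires the seed and the current step, not a saved index array.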
The benchmarking suite profiles datarax against 12+ frameworks (Grain, tf.data, PyTorch DataLoader, DALI, Ray Data, and others) across standardized scenarios. Results are converted to calibrax runs for direction-aware metrics, regression gating, and W&B export. This benchmark-driven loop is how datarax tracks progress toward competitive throughput — current results and optimization status are tracked in the benchmarking documentation.
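"Direction-aware" regression gating means a change only counts as a regression if it moves the metric in its bad direction: lower throughput, or higher latency. The sketch below illustrates that rule with hypothetical names (`Metric`, `is_regression`) and an assumed 5% relative tolerance; it is not calibrax's API, and the numbers in the usage lines are illustrative, not measured results.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class Metric:
    name: str
    value: float
    direction: Literal["higher", "lower"]  # which direction counts as better


def is_regression(current: Metric, baseline: Metric, tolerance: float = 0.05) -> bool:
    """True if `current` is worse than `baseline` by more than `tolerance` (relative)."""
    if current.direction != baseline.direction:
        raise ValueError("metric directions differ")
    if baseline.value == 0:
        raise ValueError("baseline must be non-zero for a relative comparison")
    change = (current.value - baseline.value) / abs(baseline.value)
    # "higher is better": a drop is a worsening; "lower is better": a rise is.
    worsening = -change if current.direction == "higher" else change
    return worsening > tolerance


# Illustrative values only.
print(is_regression(Metric("elements_per_s", 930, "higher"), Metric("elements_per_s", 1000, "higher")))
print(is_regression(Metric("latency_ms", 10.3, "lower"), Metric("latency_ms", 10.0, "lower")))
```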
```bash
# Basic installation
uv pip install datarax

# With data loading support (HuggingFace, TFDS, audio/image libs)
uv pip install "datarax[data]"

# With GPU support (CUDA 12)
uv pip install "datarax[gpu]"

# Full development installation
uv pip install "datarax[all]"
```
```bash
# macOS CPU mode (recommended)
uv pip install "datarax[all-cpu]"
JAX_PLATFORMS=cpu python your_script.py

# Metal GPU acceleration (experimental, M1/M2/M3+)
uv pip install jax-metal
JAX_PLATFORMS=metal python your_script.py
```
Note: Metal GPU acceleration is community-tested. CI runs on macOS with CPU only.
```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import Pipeline
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    # Deterministic: scale pixel values from [0, 255] to [0, 1]
    return element.update_data({"image": element.data["image"] / 255.0})


def augment(element: Element, key: jax.Array) -> Element:
    # Stochastic: horizontal flip with probability 0.5, driven by the supplied key
    key1, _ = jax.random.split(key)
    flip = jax.random.bernoulli(key1, 0.5)
    new_image = jax.lax.cond(
        flip,
        lambda img: jnp.flip(img, axis=1),
        lambda img: img,
        element.data["image"],
    )
    return element.update_data({"image": new_image})


# Create in-memory data source (randint's upper bound is exclusive, so 256 covers 0..255)
data = {
    "image": np.random.randint(0, 256, (1000, 28, 28, 1)).astype(np.float32),
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))

# Wrap each function in an operator
normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=normalize,
    rngs=nnx.Rngs(0),
)
augmenter = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="augmentations"),
    fn=augment,
    rngs=nnx.Rngs(42),
)

# Build a linear pipeline from the two stages (normalize, then augment)
pipeline = Pipeline(
    source=source,
    stages=[normalizer, augmenter],
    batch_size=32,
    rngs=nnx.Rngs(0),
)

# Process batches
for i, batch in enumerate(pipeline):
    if i >= 3:
        break
    print(f"Batch {i}: images {batch['image'].shape}, labels {batch['label'].shape}")
```
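The stochastic logic of `augment` can be checked outside datarax with plain JAX. This sketch mirrors its flip on bare arrays and shows explicit key threading: one key per element, split deterministically from a single seed key, so the same seed reproduces the same batch. How datarax itself derives per-element keys from `stream_name` and `rngs` is not shown here; `jax.random.split` per element is an assumption for illustration.

```python
import jax
import jax.numpy as jnp


def random_flip(image: jax.Array, key: jax.Array) -> jax.Array:
    # Same flip logic as `augment`, applied to a bare array instead of an Element.
    flip = jax.random.bernoulli(key, 0.5)
    return jax.lax.cond(flip, lambda img: jnp.flip(img, axis=1), lambda img: img, image)


def augment_batch(images: jax.Array, key: jax.Array) -> jax.Array:
    # One independent key per element, derived deterministically from `key`.
    keys = jax.random.split(key, images.shape[0])
    return jax.vmap(random_flip)(images, keys)


images = jnp.arange(4 * 2 * 3, dtype=jnp.float32).reshape(4, 2, 3, 1)
out_a = augment_batch(images, jax.random.PRNGKey(42))
out_b = augment_batch(images, jax.random.PRNGKey(42))
print(bool(jnp.array_equal(out_a, out_b)))  # same key -> identical output
```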
```python
# Define an additional operator
def invert(element: Element, key: jax.Array | None = None) -> Element:
    # Deterministic: invert normalized pixel values
    return element.update_data({"image": 1.0 - element.data["image"]})


inverter = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=invert,
    rngs=nnx.Rngs(0),
)


# Build a branching DAG:
# - normalize consumes the source
# - augment and invert each consume the normalized output independently,
#   so both branches work on the same [0, 1] scale
# - merge takes both branch outputs and averages them
class AverageMerge(nnx.Module):
    def __call__(self, augmented, inverted):
        return {
            "image": (augmented["image"] + inverted["image"]) / 2,
            "label": augmented["label"],
        }


complex_pipeline = Pipeline.from_dag(
    source=source,
    nodes={
        "normalize": normalizer,
        "augment": augmenter,
        "invert": inverter,
        "merge": AverageMerge(),
    },
    edges={
        "normalize": [],
        "augment": ["normalize"],
        "invert": ["normalize"],
        "merge": ["augment", "invert"],
    },
    sink="merge",
    batch_size=32,
    rngs=nnx.Rngs(0),
)
```
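The `nodes`/`edges`/`sink` arguments to `Pipeline.from_dag` describe a graph in which each node lists the nodes it consumes, and an empty list means it reads from the source. A minimal, hypothetical evaluator for that shape (not datarax's DAG executor, which also handles batching, caching policy, and rebatching) can be built on the standard library's `graphlib`, whose `TopologicalSorter` takes exactly a `{node: predecessors}` mapping.

```python
from graphlib import TopologicalSorter
from typing import Any, Callable


def run_dag(
    source_value: Any,
    nodes: dict[str, Callable[..., Any]],
    edges: dict[str, list[str]],
    sink: str,
) -> Any:
    """Evaluate every node once, in dependency order; [] means 'read the source'."""
    results: dict[str, Any] = {}
    for name in TopologicalSorter(edges).static_order():
        parents = edges[name]  # every node must appear as a key in `edges`
        inputs = [results[p] for p in parents] if parents else [source_value]
        # Each node runs once; nodes with several consumers reuse the stored result.
        results[name] = nodes[name](*inputs)
    return results[sink]


# Toy functions on a scalar "pixel" standing in for operators.
nodes = {
    "normalize": lambda x: x / 255.0,
    "augment": lambda x: x * 0.5,
    "invert": lambda x: 1.0 - x,
    "merge": lambda a, b: (a + b) / 2,
}
edges = {
    "normalize": [],
    "augment": ["normalize"],
    "invert": ["normalize"],
    "merge": ["augment", "invert"],
}
print(run_dag(255.0, nodes, edges, sink="merge"))  # 0.25
```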
```text
src/datarax/
  core/         # Base modules: DataSourceModule, OperatorModule, Element, Batcher, Sampler, Sharder
  pipeline/     # Pipeline (nnx.Module): linear stages and Pipeline.from_dag for branching
  sources/      # MemorySource, TFDS (eager/streaming), HuggingFace (eager/streaming), ArrayRecord, MixedSource, StreamingDiskSource
  operators/    # ElementOperator, MapOperator, CompositeOperator, modality-specific (image, audio)
  strategies/   # Sequential, Parallel, Branching, Ensemble, Merging composition strategies
  samplers/     # Sequential, Shuffle (Feistel cipher), Range, EpochAware, SlidingWindow, BufferSampler
  batching/     # DefaultBatcher with buffer state management
  sharding/     # ArraySharder, JaxProcessSharder for multi-device distribution
  distributed/  # DeviceMesh, DataParallel for multi-host training
  checkpoint/   # Orbax integration (NNX-standard checkpoint pattern)
  monitoring/   # MetricsCollector, callbacks, reporters (console/file)
  performance/  # Roofline analysis, XLA optimization utilities
  control/      # Prefetcher for asynchronous data loading
  memory/       # Shared memory manager for multi-process data sharing
  workers/      # Reserved namespace for the planned multiprocessing backend
  config/       # TOML-based configuration system with schema validation
  cli/          # datarax CLI entry point
  utils/        # PyTree utilities, external integration helpers
```
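The `performance/` package lists roofline analysis. The API of datarax's utilities is not shown here; the sketch below only states the standard roofline model they are named after: attainable throughput is the minimum of peak compute and arithmetic intensity (FLOPs per byte moved) times memory bandwidth. The accelerator figures are hypothetical, chosen to make the arithmetic easy to follow.

```python
def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte transferred to or from memory."""
    return flops / bytes_moved


def attainable_flops(peak_flops: float, mem_bandwidth: float, intensity: float) -> float:
    """Roofline model: performance is capped by compute or by memory traffic."""
    return min(peak_flops, intensity * mem_bandwidth)


# Elementwise `x / 255.0` on float32: 1 FLOP per element, 4 bytes read + 4 written.
n = 1_000_000
ai = arithmetic_intensity(flops=n, bytes_moved=8 * n)  # 0.125 FLOP/byte

# Hypothetical accelerator: 100 TFLOP/s peak compute, 1 TB/s memory bandwidth.
peak, bandwidth = 100e12, 1e12
perf = attainable_flops(peak, bandwidth, ai)
ridge = peak / bandwidth  # intensity at which the two limits meet (FLOP/byte)
print(ai, perf, ridge)
```

With an intensity of 0.125 against a ridge point of 100 FLOP/byte, elementwise preprocessing like this is heavily memory-bound, which is why fusing such operators under XLA tends to matter more than raw compute.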
Datarax includes a benchmarking suite for comparison against 12+ data loading frameworks across a range of workload scenarios (vision, NLP, tabular, multimodal, distributed).
```bash
# Install benchmark dependencies (adds PyTorch, DALI, Ray, etc.)
uv sync --extra benchmark

# Optional: install calibrax with W&B support explicitly
uv pip install "calibrax[wandb] @ git+https://github.com/avitai/calibrax.git"

# Run benchmarks locally
uv run python -m benchmarks.runners.full_runner --platform cpu --repetitions 5

# Run on cloud (SkyPilot)
sky launch benchmarks/sky/gpu-benchmark.yaml --env WANDB_API_KEY=$WANDB_API_KEY
```
Benchmark results are exported to W&B with charts, gap analysis, stability reports, and raw result artifacts. See Benchmarking Guide for methodology and cloud deployment.
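Repeated runs (as with `--repetitions 5` above) guard against noisy single measurements. The harness below is a hypothetical sketch of that idea, not the benchmark runner's code: untimed warm-up batches absorb JIT compilation, each repetition uses a fresh iterator, and the median elements/second is reported. For JAX outputs, pass `sync=jax.block_until_ready` so asynchronous dispatch does not inflate the numbers.

```python
import statistics
import time
from typing import Any, Callable, Iterable


def measure_throughput(
    make_iter: Callable[[], Iterable[Any]],
    num_batches: int,
    batch_size: int,
    repetitions: int = 5,
    warmup: int = 2,
    sync: Callable[[Any], Any] = lambda batch: batch,
) -> float:
    """Median elements/second over `repetitions` fresh passes (hypothetical harness)."""
    rates = []
    for _ in range(repetitions):
        it = iter(make_iter())  # fresh iterator per repetition
        for _ in range(warmup):  # untimed: absorbs compilation and cache warm-up
            sync(next(it))
        start = time.perf_counter()
        for _ in range(num_batches):
            sync(next(it))  # wait for the batch to actually be ready
        elapsed = max(time.perf_counter() - start, 1e-9)
        rates.append(num_batches * batch_size / elapsed)
    return statistics.median(rates)


def toy_batches():
    # Stand-in for a data pipeline: 100 batches of 32 elements each.
    for _ in range(100):
        yield list(range(32))


rate = measure_throughput(toy_batches, num_batches=50, batch_size=32)
print(f"{rate:.0f} elements/s")
```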
Datarax uses uv as its package manager:
```bash
# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax

# Automatic setup
./setup.sh && source activate.sh

# Or manual install
uv sync --extra dev
```
```bash
# CPU-only (most stable)
JAX_PLATFORMS=cpu uv run pytest

# Include benchmark test suite in the same run
JAX_PLATFORMS=cpu uv run pytest --all-suites

# Specific module
JAX_PLATFORMS=cpu uv run pytest tests/sources/test_memory_source_module.py
```
```bash
# Build and run
docker build -t datarax:latest .
docker run --rm --gpus all datarax:latest python -c "import datarax, jax; print(jax.devices())"

# Benchmark images
docker build -f benchmarks/docker/Dockerfile.gpu -t datarax-bench:gpu .
```
See Docker Guide for full details.
- Installation Guide
- Quick Start
- Core Concepts
- User Guide
- API Reference
- Examples
- Benchmarking
- Contributing
- Docker
Datarax is licensed under the MIT License.