Yes, and JAX is arguably the best framework for multi-GPU training if you're willing to learn its paradigm. While PyTorch dominates mindshare, JAX's XLA compiler and functional design make it exceptionally good at distributed training: the stack grew up on Google's TPU pods, and the same abstractions carry over to multi-GPU setups on io.net.
The catch: JAX thinks about distribution differently from PyTorch. There's no DistributedDataParallel wrapper. Instead, you either map a function over devices with jax.pmap and call collectives such as pmean yourself, or use the newer jax.sharding APIs to declare how arrays are laid out and let the XLA compiler insert the communication.
Getting Started on io.net
Spin up a multi-GPU instance with CUDA 12.x and install JAX with GPU support:
pip install -U "jax[cuda12]" flax optax
Verify JAX sees all GPUs:
import jax
print(jax.devices()) # Should list all GPUs
# [CudaDevice(id=0), CudaDevice(id=1), ..., CudaDevice(id=7)]
Data Parallelism with jax.pmap
This is the simplest form of distributed training in JAX. pmap (parallel map) replicates a function across all local GPUs, each processing a different shard of the batch:
from functools import partial

import jax
import jax.numpy as jnp

# `model`, `cross_entropy`, `shard_batch`, and the initial `params` are
# placeholders for your own model, loss, and batching code.
num_gpus = jax.local_device_count()

# axis_name must match the name passed to pmean below
@partial(jax.pmap, axis_name='batch')
def train_step(params, batch):
    def loss_fn(p):
        logits = model.apply(p, batch['input'])
        return jnp.mean(cross_entropy(logits, batch['label']))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients across devices (an AllReduce)
    grads = jax.lax.pmean(grads, axis_name='batch')
    params = jax.tree.map(lambda p, g: p - 0.001 * g, params, grads)
    return params, loss

# Replicate params across GPUs
params = jax.tree.map(lambda x: jnp.stack([x] * num_gpus), params)
# Shard batch across GPUs (first axis = device)
batch = shard_batch(batch, num_gpus)
# Train: this runs on all GPUs simultaneously
params, loss = train_step(params, batch)
That jax.lax.pmean call is where the magic happens — it triggers an AllReduce across devices, averaging gradients just like PyTorch DDP, but expressed as a pure function.
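The snippet above leaves the model, the loss, and shard_batch undefined. As a minimal self-contained sketch, here is the same pattern with a toy linear-regression model. The shard_batch helper, the toy data, and the learning rate are all illustrative assumptions, not part of any library, and the code runs on however many local devices JAX finds (including a single one).

```python
from functools import partial

import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()


def shard_batch(batch, num_devices):
    """Hypothetical helper: reshape (N, ...) arrays to (num_devices, N // num_devices, ...)."""
    def _shard(x):
        per_device = x.shape[0] // num_devices
        return x.reshape((num_devices, per_device) + x.shape[1:])
    return jax.tree.map(_shard, batch)


@partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    def loss_fn(p):
        pred = batch["input"] @ p["w"] + p["b"]
        return jnp.mean((pred - batch["label"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients and loss over the 'batch' device axis
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    # Plain SGD step; every device applies the same averaged gradient
    params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


# Toy parameters, replicated along a new leading device axis
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
replicated = jax.tree.map(lambda x: jnp.stack([x] * num_devices), params)

# Toy regression data: 32 examples per device
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32 * num_devices, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
batch = shard_batch({"input": x, "label": y}, num_devices)

losses = []
for step in range(50):
    replicated, loss = train_step(replicated, batch)
    losses.append(float(loss[0]))  # loss is identical on every device after pmean

first_loss, final_loss = losses[0], losses[-1]
print(f"loss: {first_loss:.4f} -> {final_loss:.4f}")
```

Because the gradients are averaged before the update, the replicated parameters stay identical on every device, which is why reading loss[0] (or the first replica of the parameters) is enough.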
Modern Approach: Named Sharding (JAX 0.4+)
For more complex parallelism strategies (model + data parallelism, expert parallelism for MoE), JAX's named sharding API is more flexible:
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Define a 2D mesh: 4 data-parallel × 2 model-parallel (assumes 8 GPUs)
devices = jax.devices()
mesh = Mesh(np.array(devices).reshape(4, 2), ('data', 'model'))
# Shard 2D model weights across 'model' axis, replicate across 'data'
param_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
# Shard batch across 'data' axis
batch_sharding = NamedSharding(mesh, PartitionSpec('data', None))

# Place arrays with jax.device_put(x, sharding); jit then compiles for that layout
@jax.jit
def train_step(params, batch):
    # XLA inserts AllReduce, AllGather as needed
    ...
This declarative approach lets you experiment with different parallelism configurations, often by changing only the mesh shape and partition specs while the training step itself stays the same.
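To make the flow concrete, the sketch below places a weight matrix and a batch on a mesh with jax.device_put and runs a jitted forward pass over them. It is an illustration under assumptions: the mesh is sized from whatever devices are available (2-way model parallelism only if the device count is even), and the array shapes and the forward function are made up for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
# Assumption: split the model 2 ways only when the device count allows it
model_axis = 2 if devices.size % 2 == 0 else 1
mesh = Mesh(devices.reshape(-1, model_axis), ("data", "model"))
n_data = mesh.shape["data"]

weight_sharding = NamedSharding(mesh, P(None, "model"))  # split columns over 'model'
batch_sharding = NamedSharding(mesh, P("data", None))    # split rows over 'data'

# Place the arrays according to the declared layouts
w = jax.device_put(jnp.ones((16, 8)), weight_sharding)
x = jax.device_put(jnp.ones((4 * n_data, 16)), batch_sharding)


@jax.jit
def forward(w, x):
    # No explicit collectives: the compiler works out any communication needed
    return jnp.tanh(x @ w)


out = forward(w, x)
print(out.shape)
```

Changing model_axis or the partition specs changes how the work is split, while forward itself is untouched.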
Why JAX for Distributed Training
A few things JAX does better than alternatives:
XLA compilation. JAX compiles your entire training step into a single XLA graph, which the compiler optimizes globally — fusing operations, overlapping communication with compute, and eliminating unnecessary copies. PyTorch's torch.compile is catching up, but JAX has years of head start.
Functional purity. No global state, no in-place mutations. This makes it trivial for the compiler to reason about data dependencies and parallelize safely. The tradeoff is a steeper learning curve if you're coming from PyTorch.
FSDP-like sharding for free. JAX's sharding API natively supports fully-sharded data parallelism (splitting optimizer states, gradients, and parameters across GPUs) without a separate library. You just specify the partition spec.
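Reproducibility. JAX's explicit PRNG handling means random numbers are derived from keys you pass around, not from hidden per-device state, so initialization and dropout masks do not depend on how many GPUs you run on. Floating-point reduction order can still introduce small numerical differences between device counts, but the sources of randomness stay under your control, which is valuable for research.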
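As a rough sketch of the FSDP-like point above, the example below shards parameters and a momentum buffer along their first axis across a 1D 'data' mesh, and shards the batch along the same axis. The toy model, the hand-written momentum update, and all shapes are assumptions for illustration; a real setup would typically use Optax and per-layer partition specs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
n = mesh.shape["data"]

shard_rows = NamedSharding(mesh, P("data", None))  # split first axis across devices
replicated = NamedSharding(mesh, P())              # full copy on every device

# Parameters and optimizer state are both split across the 'data' axis
params = jax.device_put({"w": jnp.zeros((8 * n, 4))}, shard_rows)
momentum = jax.device_put({"w": jnp.zeros((8 * n, 4))}, shard_rows)

# The batch is split across the same axis (data parallelism)
x = jax.device_put(jnp.ones((2 * n, 8 * n)), shard_rows)
y = jax.device_put(jnp.ones((2 * n, 4)), shard_rows)


def train_step(params, momentum, x, y):
    def loss_fn(p):
        return jnp.mean((x @ p["w"] - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    momentum = jax.tree.map(lambda m, g: 0.9 * m + g, momentum, grads)
    params = jax.tree.map(lambda p, m: p - 0.01 * m, params, momentum)
    return params, momentum, loss


# Pin the outputs to the sharded layout so state stays split between steps
train_step = jax.jit(train_step, out_shardings=(shard_rows, shard_rows, replicated))

params, momentum, loss1 = train_step(params, momentum, x, y)
params, momentum, loss2 = train_step(params, momentum, x, y)
print(float(loss1), float(loss2))
```

Pinning out_shardings is a design choice here: without it the compiler is free to pick the output layout, so stating it keeps the parameters and optimizer state sharded from one step to the next.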
Recommended GPU Configurations
| Workload | Config on io.net | Monthly cost (24/7) |
|---|---|---|
| JAX research (small models) | 4x RTX 4090 | $518 |
| JAX training (7-13B models) | 8x A100 80GB | $8,597 |
| JAX large-scale (70B+) | 8x H100 SXM | $12,672 |
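For comparison across rows, the monthly figures can be converted to an implied per-GPU hourly rate. The table does not say how many hours it counts as a month, so the sketch below assumes a 30-day month (720 hours); treat the results as approximate.

```python
# Assumption: "24/7" means a 30-day month, i.e. 720 hours
HOURS_PER_MONTH = 24 * 30

# (number of GPUs, monthly cost in USD) taken from the table above
configs = {
    "4x RTX 4090": (4, 518),
    "8x A100 80GB": (8, 8597),
    "8x H100 SXM": (8, 12672),
}


def implied_gpu_hour_rate(num_gpus, monthly_cost, hours=HOURS_PER_MONTH):
    """Monthly cluster cost divided by total GPU-hours in the month."""
    return monthly_cost / (num_gpus * hours)


rates = {name: implied_gpu_hour_rate(n, cost) for name, (n, cost) in configs.items()}
for name, rate in rates.items():
    print(f"{name}: ~${rate:.2f} per GPU-hour")
# 4x RTX 4090: ~$0.18 per GPU-hour
# 8x A100 80GB: ~$1.49 per GPU-hour
# 8x H100 SXM: ~$2.20 per GPU-hour
```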
JAX Ecosystem for Training
- Flax: Neural network library (replaces PyTorch's nn.Module)
- Optax: Optimizers (Adam, SGD, schedules, gradient clipping)
- Orbax: Checkpointing (distributed checkpoint saving/loading)
- MaxText: Google's reference implementation for training LLMs with JAX
- T5X: Training framework used for PaLM, Gemini research
Run JAX on io.net: multi-GPU clusters with NVLink, CUDA 12.x pre-installed.
