Gradient checkpointing trades compute time for memory. Instead of storing all intermediate activations during the forward pass (which eat up massive amounts of VRAM), it only keeps a few "checkpoints" and recomputes the rest during backpropagation. The result: 60-80% less activation memory at the cost of roughly 20-30% slower training.

Combined with LoRA and a quantized base model, it is part of how people fit a 13B model on a single 24GB RTX 4090. On its own, it is what lets full fine-tuning of a 7B model fit on one A100 80GB instead of pushing past that card's limit.

The Memory Problem It Solves

During training, PyTorch's autograd saves the intermediate tensors from the forward pass that it needs to compute gradients in the backward pass. For a transformer model, this means keeping the outputs of every attention layer, every MLP block, and every normalization layer, for every token of every sequence in the batch.

For a 7B parameter model with a batch of 4 sequences at length 2048, activation memory alone can consume 15-25GB. That's on top of the model weights (~14GB in FP16), optimizer states (~28GB for AdamW's two moment buffers at 2 bytes per value; FP32 states would double that), and gradients (~14GB). Total memory demand: 70-80GB.

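Gradient checkpointing cuts the activation portion by 60-80%, leaving roughly 3-10GB of activations and bringing total demand down to about 59-66GB. That fits comfortably on an A100 80GB, but not on a 40GB card: weights, gradients, and optimizer states alone already take ~56GB.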
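The arithmetic behind those totals, as a quick back-of-the-envelope script. It simply reuses the rough figures quoted above (2 bytes per parameter for weights and gradients, two 2-byte AdamW moment buffers, 15-25GB of activations); the budget helper is hypothetical and nothing here is a measurement.

```python
# Rough training-memory budget for the 7B example, using the estimates from the text.
PARAMS = 7e9

def budget(activation_gb: float, checkpoint_savings: float = 0.0) -> float:
    """Total memory in GB: weights + gradients + AdamW states + (possibly reduced) activations."""
    weights = PARAMS * 2 / 1e9        # FP16 weights, 2 bytes each -> 14 GB
    gradients = PARAMS * 2 / 1e9      # one 16-bit gradient per parameter -> 14 GB
    optimizer = PARAMS * 2 * 2 / 1e9  # AdamW: two moment buffers at 2 bytes -> 28 GB
    activations = activation_gb * (1.0 - checkpoint_savings)
    return weights + gradients + optimizer + activations

for act in (15, 25):
    before = budget(act)
    best = budget(act, checkpoint_savings=0.80)   # checkpointing saves 80% of activations
    worst = budget(act, checkpoint_savings=0.60)  # checkpointing saves 60% of activations
    print(f"activations {act} GB: {before:.0f} GB -> {best:.0f}-{worst:.0f} GB")
# prints "activations 15 GB: 71 GB -> 59-62 GB" and "activations 25 GB: 81 GB -> 61-66 GB"
```

The fixed 56GB for weights, gradients, and optimizer states is what keeps this workload off a 40GB card no matter how much activation memory is saved.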

How It Works Under the Hood

Normal training stores everything:

Layer 1 → [save activation] → Layer 2 → [save activation] → ... → Layer N → Loss

With checkpointing, you pick a few layers to save, discard the rest:

Layer 1 → [SAVE] → Layer 2 → [discard] → Layer 3 → [discard] → Layer 4 → [SAVE] → ...

During backpropagation, when gradients reach a discarded layer, PyTorch re-runs the forward pass from the nearest checkpoint to regenerate that activation on the fly. More compute, less memory.
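You can watch the recomputation happen with a toy block that counts how often its forward method runs. This is a minimal sketch using PyTorch's torch.utils.checkpoint.checkpoint; the CountingBlock class and its sizes are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """Toy block (hypothetical) that records how many times its forward pass runs."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1  # counts the original forward and any recomputation
        return torch.relu(self.linear(x))

block = CountingBlock()
x = torch.randn(4, 16, requires_grad=True)

# Checkpointed call: the block's internal activations are discarded after the forward.
y = checkpoint(block, x, use_reentrant=False)
print("calls after forward:", block.calls)   # 1
y.sum().backward()
print("calls after backward:", block.calls)  # 2: the forward re-ran to rebuild activations
```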

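Where to put the checkpoints is a trade-off. HuggingFace models checkpoint at every transformer block: each block's input is kept and everything inside the block is recomputed. Keeping fewer checkpoints saves more memory, but each recomputed segment gets longer and its activations must be held briefly during backpropagation; the classic rule of thumb splits N layers into about √N segments, which minimizes peak activation memory. Either way, each discarded activation is recomputed about once, so the extra compute is roughly one additional forward pass.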
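The segment idea maps directly onto torch.utils.checkpoint.checkpoint_sequential, which splits an nn.Sequential into a given number of segments. A minimal sketch on a toy 16-layer stack (layer sizes are hypothetical), using √16 = 4 segments and checking that recomputation leaves the gradients unchanged:

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
num_layers = 16
# Toy stand-in for a stack of transformer blocks.
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(num_layers)])
x = torch.randn(8, 64, requires_grad=True)

# Reference run: ordinary forward/backward with all activations stored.
model(x).sum().backward()
ref_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed run: 4 segments; all but the last are checkpointed, keeping only segment inputs.
segments = int(math.sqrt(num_layers))
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()

# Recomputation changes memory use, not the math.
for ref, p in zip(ref_grads, model.parameters()):
    assert torch.allclose(ref, p.grad, atol=1e-6)
print("gradients match with", segments, "segments")
```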

Enabling It

It's a one-liner in most frameworks.

HuggingFace Transformers:

model.gradient_checkpointing_enable()
# That's it. Seriously.

PyTorch native:

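import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        # Illustrative blocks; substitute your own layers
        self.block1 = nn.Sequential(nn.Linear(dim, dim), nn.GELU())
        self.block2 = nn.Sequential(nn.Linear(dim, dim), nn.GELU())
        self.head = nn.Linear(dim, 10)

    def forward(self, x):
        # Checkpoint specific blocks: only their inputs are kept; internals are recomputed in backward
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)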
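Why pass use_reentrant=False? PyTorch's documentation recommends the non-reentrant implementation, and one practical reason shows up in LoRA-style setups where the checkpointed block's input comes from a frozen layer and does not require grad. The non-reentrant variant still routes gradients to the block's own parameters; the older reentrant variant warns and leaves them as None in that case, which is why setups using it typically have to mark inputs as requiring grad. A minimal sketch with hypothetical toy layers:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Linear(32, 32)  # stand-in for a trainable checkpointed block
head = nn.Linear(32, 1)

# Input produced by a frozen layer (e.g. frozen embeddings): it does NOT require grad.
x = torch.randn(4, 32)

h = checkpoint(block, x, use_reentrant=False)
loss = head(h).sum()
loss.backward()

# The block's parameters still receive gradients.
print(block.weight.grad is not None)  # True
```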

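DeepSpeed (these keys configure DeepSpeed's own activation-checkpointing module, so they only take effect when the model's forward pass calls deepspeed.checkpointing.checkpoint; partition_activations only helps with model parallelism, and number_checkpoints sizes the buffer used by contiguous_memory_optimization):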

{
  "activation_checkpointing": {
    "partition_activations": true,
    "number_checkpoints": 12
  }
}

Real Impact on Different Workloads

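Model                           | Without checkpointing | With checkpointing | Total memory saved | Speed penalty
--------------------------------|-----------------------|--------------------|--------------------|--------------
Llama 3 8B (LoRA, batch 4)      | 22GB                  | 15GB               | 32%                | 18%
Llama 3 8B (full FT, batch 4)   | 72GB                  | 48GB               | 33%                | 25%
13B model (LoRA, batch 2)       | 34GB                  | 21GB               | 38%                | 20%
Llama 3 70B (LoRA, batch 1)     | 95GB                  | 52GB               | 45%                | 30%
SDXL training (batch 8)         | 28GB                  | 16GB               | 43%                | 15%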

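The larger the model and the longer the sequence, the bigger the savings. Transformer activation memory grows roughly with layers x batch_size x sequence_length x hidden_dim, and checkpointing replaces the many intermediate tensors stored per block with a single saved input per checkpoint, so the absolute savings grow with every one of those factors.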

When to Use It (and When Not To)

Use gradient checkpointing when:
- Your model barely fits in GPU memory (OOM errors with your desired batch size)
- You'd rather train slower on fewer GPUs than rent more expensive hardware
- You're doing full fine-tuning of 7B+ models on A100 40GB or smaller
- Sequence lengths are long (4K+ tokens) and activations dominate memory

Skip it when:
- Model fits comfortably with room to spare (wasting 20-30% speed for nothing)
- You're doing inference only (no activations to store)
- You're already paying for A100 80GB or H100 and have memory headroom
- Training time is the bottleneck, not memory
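The inference case can be handled in code: checkpoint only when a backward pass can actually happen. Many HuggingFace model implementations guard their checkpointing with a similar "gradient checkpointing enabled and module in training mode" check. The MaybeCheckpointed wrapper below is a hypothetical sketch of that pattern, not a library class.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MaybeCheckpointed(nn.Module):
    """Hypothetical wrapper: checkpoints the block only during training with grad enabled."""
    def __init__(self, block: nn.Module, enabled: bool = True):
        super().__init__()
        self.block = block
        self.enabled = enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.enabled and self.training and torch.is_grad_enabled():
            return checkpoint(self.block, x, use_reentrant=False)
        # Eval mode or no_grad: nothing is saved for backward, so checkpointing
        # would only add overhead.
        return self.block(x)

torch.manual_seed(0)
layer = MaybeCheckpointed(nn.Linear(8, 8))
x = torch.randn(2, 8)

layer.train()
train_out = layer(x)        # checkpointed path

layer.eval()
with torch.no_grad():
    eval_out = layer(x)     # plain path, no recomputation

print(torch.allclose(train_out, eval_out))  # True: same math, different memory strategy
```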

Combining with Other Techniques

Gradient checkpointing stacks well with:

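  • Mixed precision (BF16): Storing weights and activations in BF16 instead of FP32 halves their memory. Combined with checkpointing, total memory can drop by 70-80% compared with an FP32 run without checkpointing.
  • LoRA fine-tuning: Trains small adapter matrices instead of the full weights, cutting trainable parameters by ~99% and slashing gradient and optimizer-state memory. With checkpointing on top, and the base weights quantized to 4-bit (as in QLoRA), you can LoRA fine-tune 70B models on a single A100 80GB.
  • Gradient accumulation: Split each batch into smaller micro-batches, run them sequentially, and sum their gradients before a single optimizer step. Peak activation memory follows the micro-batch size while the effective batch size stays the same; the cost is more forward/backward passes per update.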
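Three of these pieces fit together in a few lines of PyTorch. A minimal sketch on a toy model with made-up data: BF16 autocast, block-level checkpointing, and gradient accumulation over 4 micro-batches. LoRA is left out to keep it dependency-free, and it runs on CPU; on a GPU you would use device_type="cuda".

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

class Net(nn.Module):
    """Toy model (hypothetical sizes) whose blocks are checkpointed in forward."""
    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )
        self.head = nn.Linear(dim, 1)

    def forward(self, x):
        for block in self.blocks:
            x = checkpoint(block, x, use_reentrant=False)  # recompute block internals in backward
        return self.head(x)

model = Net()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4  # effective batch = accum_steps * micro-batch size

# Made-up data: 4 micro-batches of 8 examples each.
micro_batches = [(torch.randn(8, 32), torch.randn(8, 1)) for _ in range(accum_steps)]

optimizer.zero_grad()
for inputs, targets in micro_batches:
    # BF16 autocast: linear layers run in bfloat16 inside this context.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        preds = model(inputs)
    loss = nn.functional.mse_loss(preds.float(), targets)
    # Scale so the summed gradients equal the average over the effective batch.
    (loss / accum_steps).backward()
optimizer.step()  # one update per accum_steps micro-batches
```

Only one micro-batch's (checkpointed) activations are alive at a time, which is where the two techniques compound.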

The full stack (BF16 + LoRA + gradient checkpointing + gradient accumulation), with the base weights quantized for the largest models, is how people fine-tune 70B models on a single 80GB GPU instead of a multi-GPU cluster. It works.

