Multi-GPU and multi-node distributed training
This feature is in Beta. Workspace admins can control access to this feature from the Previews page. See Manage Databricks previews.
This page provides notebook examples for multi-node and multi-GPU distributed training on Serverless GPU compute. The examples show how to scale training across multiple GPUs and nodes to increase training throughput.
Multi-node distributed training is currently only supported on A10 GPUs. Multi-GPU distributed training is supported on both A10 and H100 GPUs.
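The support matrix above can be written as a small pre-flight check before you launch a job. This is a minimal sketch under stated assumptions: the function names `check_topology` and `global_rank` and the GPU-type strings `"A10"` and `"H100"` are hypothetical and are not part of any Databricks API. The check also returns the world size (total number of training processes, one per GPU), and `global_rank` uses the common `node_rank * gpus_per_node + local_rank` convention to give each process a unique rank.

```python
# Hypothetical pre-flight check encoding the support matrix described on this page:
# multi-node training only on A10; multi-GPU (single node) training on A10 or H100.
# The GPU-type strings are illustrative assumptions, not a Databricks API.
MULTI_GPU_TYPES = {"A10", "H100"}
MULTI_NODE_TYPES = {"A10"}


def check_topology(gpu_type: str, num_nodes: int, gpus_per_node: int) -> int:
    """Validate a requested topology and return the world size (total processes)."""
    if num_nodes < 1 or gpus_per_node < 1:
        raise ValueError("num_nodes and gpus_per_node must be positive")
    if num_nodes > 1 and gpu_type not in MULTI_NODE_TYPES:
        raise ValueError(f"multi-node training is not supported on {gpu_type}")
    if gpus_per_node > 1 and gpu_type not in MULTI_GPU_TYPES:
        raise ValueError(f"multi-GPU training is not supported on {gpu_type}")
    # One process per GPU, so the world size is nodes * GPUs per node.
    return num_nodes * gpus_per_node


def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    """Map (node index, local GPU index) to a unique rank in [0, world_size)."""
    return node_rank * gpus_per_node + local_rank


if __name__ == "__main__":
    print(check_topology("A10", num_nodes=2, gpus_per_node=4))
    print([global_rank(n, l, 4) for n in range(2) for l in range(4)])
```

With two A10 nodes of four GPUs each, the world size is 8 and the ranks run from 0 to 7. Requesting two H100 nodes raises an error, matching the constraint stated above.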
Choose your parallelism technique
When scaling your model training across multiple GPUs, choosing the right parallelism technique depends on your model size, available GPU memory, and performance requirements.
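A rough way to compare model size with available GPU memory is to estimate the per-GPU memory needed for model state. The sketch below assumes the mixed-precision Adam memory model from the ZeRO paper (Rajbhandari et al., 2020): 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for 32-bit optimizer state (master weights, momentum, and variance), or 16 bytes per parameter in total. DDP replicates all of it on every GPU. ZeRO stages 1, 2, and 3 progressively shard optimizer state, then gradients, then parameters across N GPUs, and FSDP's full-sharding mode is roughly comparable to stage 3. Activations, temporary buffers, and framework overhead are ignored, so treat the results as a lower bound. The function name `per_gpu_gb` and the 7B-parameter, 8-GPU example are illustrative and not taken from the example notebooks.

```python
# Per-GPU model-state memory under the ZeRO paper's mixed-precision Adam model.
# Assumption: 2 B/param weights, 2 B/param gradients, 12 B/param optimizer state.
# Activations and buffers are NOT included, so real usage is higher.
PARAM_BYTES, GRAD_BYTES, OPTIM_BYTES = 2, 2, 12


def per_gpu_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Estimate model-state memory per GPU in GB (1 GB = 1e9 bytes).

    stage 0 = DDP (full replica on every GPU)
    stage 1 = shard optimizer state
    stage 2 = also shard gradients
    stage 3 = also shard parameters (roughly FSDP full sharding)
    """
    p, n = num_params, num_gpus
    if stage == 0:
        total = (PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES) * p
    elif stage == 1:
        total = (PARAM_BYTES + GRAD_BYTES) * p + OPTIM_BYTES * p / n
    elif stage == 2:
        total = PARAM_BYTES * p + (GRAD_BYTES + OPTIM_BYTES) * p / n
    elif stage == 3:
        total = (PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES) * p / n
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


# Illustrative model: 7 billion parameters trained on 8 GPUs.
for stage in range(4):
    print(stage, per_gpu_gb(7e9, 8, stage))
```

Under these assumptions, a 7B-parameter model needs about 112 GB of model state per GPU with DDP, more than a 24 GB A10 or an 80 GB H100 provides, while sharding optimizer state, gradients, and parameters across 8 GPUs reduces it to about 38.5, 26.25, and 14 GB respectively.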
| Technique | When to use |
|---|---|
| DDP (Distributed Data Parallel) | The full model fits in a single GPU's memory and you need to scale data throughput |
| FSDP (Fully Sharded Data Parallel) | Very large models that don't fit in a single GPU's memory |
| DeepSpeed ZeRO | Large models that need advanced memory optimizations |
For detailed information about each technique, see DDP, FSDP, and DeepSpeed.
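The main practical difference between DDP and FSDP is what each rank stores and which collective operation it uses to stay in sync. The following pure-Python simulation (no GPUs or PyTorch needed) mimics both patterns on plain lists: DDP keeps a full copy of the parameters on every rank and averages full gradients with an all-reduce, while FSDP keeps only a 1/N shard per rank, all-gathers the full parameters when it needs them for compute, and reduce-scatters gradients so each rank receives the averaged gradient for its own shard. The helper names are hypothetical; real training code uses collectives such as `torch.distributed.all_reduce`, `all_gather`, and `reduce_scatter` over GPU tensors.

```python
from typing import List

Vector = List[float]


def all_reduce_mean(per_rank: List[Vector]) -> List[Vector]:
    """DDP-style sync: every rank ends up with the element-wise mean of all gradients."""
    n = len(per_rank)
    mean = [sum(vals) / n for vals in zip(*per_rank)]
    return [list(mean) for _ in range(n)]


def shard(params: Vector, world_size: int) -> List[Vector]:
    """FSDP-style layout: split a flat vector into equal contiguous shards,
    zero-padding the tail so the length divides evenly across ranks."""
    size = -(-len(params) // world_size)  # ceiling division
    padded = params + [0.0] * (size * world_size - len(params))
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards: List[Vector], numel: int) -> Vector:
    """Rebuild the full (unpadded) parameter vector from every rank's shard."""
    full = [x for s in shards for x in s]
    return full[:numel]


def reduce_scatter_mean(per_rank_full_grads: List[Vector]) -> List[Vector]:
    """Average full gradients across ranks, then keep only each rank's own shard."""
    world_size = len(per_rank_full_grads)
    mean = all_reduce_mean(per_rank_full_grads)[0]
    return shard(mean, world_size)


# Two ranks and a 5-element "model"; each rank computed a different local gradient.
params = [1.0, 2.0, 3.0, 4.0, 5.0]
grads = [[1.0, 1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0, 3.0]]

print(all_reduce_mean(grads)[1])        # DDP: rank 1 holds the full averaged gradient
shards = shard(params, 2)
print(shards)                           # FSDP: each rank stores about half the parameters
print(all_gather(shards, len(params)))  # full parameters rebuilt for forward/backward
print(reduce_scatter_mean(grads))       # each rank receives its averaged gradient shard
```

DeepSpeed ZeRO stage 3 follows the same shard, all-gather, and reduce-scatter pattern, while stages 1 and 2 keep full parameters on every rank and shard only optimizer state and gradients.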
Example notebooks by technique and framework
The following table organizes example notebooks by the framework/library you're using and the parallelism technique applied. Multiple notebooks may appear in a single cell.
| Framework/Library | DDP examples | FSDP examples | DeepSpeed examples |
|---|---|---|---|
| PyTorch (native) | — | | |
| — | — | | |
| — | — | | |
| — | — | | |
| — | — | | |
| — | — | | |
Get started
Use the following tutorials to get started with the serverless GPU Python library for distributed training:
| Tutorial | Description |
|---|---|
| | This notebook demonstrates how to use serverless GPU compute to run GPU workloads on A10 GPUs directly from Databricks notebooks. Learn how to use the Serverless GPU Python library to execute functions on single or multiple GPUs for distributed training. |
| | Learn how to use Databricks Serverless GPU compute with H100 accelerators to run distributed GPU workloads using the `serverless_gpu` Python library. |