Multi-GPU and multi-node distributed training

Beta

This feature is in Beta.

This page provides notebook examples for multi-GPU and multi-node distributed training on Serverless GPU compute. The examples demonstrate how to scale training across multiple GPUs and nodes to improve performance.

note

Multi-node distributed training is currently supported only on A10 GPUs. Multi-GPU distributed training is supported on both A10 and H100 GPUs.

Before running these notebooks, see the Best practices checklist.

Serverless GPU API: A10 starter

The following notebook provides a basic example of using the Serverless GPU Python API to launch distributed training across multiple A10 GPUs.

Notebook: Open notebook in new tab
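
For orientation before opening the notebook, the sketch below shows the general launch pattern: decorate a training function and submit it to remote A10 GPUs. This is a minimal sketch, not the notebook's code; the `serverless_gpu` module name, the `@distributed` decorator, its `gpus`, `gpu_type`, and `remote` parameters, and the `.distributed()` launch call are assumptions about the Serverless GPU Python API, so confirm the exact names against the starter notebook.

```python
# Minimal sketch (not the notebook's code): launch a training function on
# remote A10 GPUs with the Serverless GPU Python API. The module name,
# decorator, parameter names, and the .distributed() launch call are
# assumptions -- verify them in the starter notebook.
from serverless_gpu import distributed

import torch


@distributed(gpus=4, gpu_type="A10", remote=True)
def train():
    # Runs once per worker process; each worker sees its own GPU.
    print(f"CUDA available: {torch.cuda.is_available()}, "
          f"devices visible: {torch.cuda.device_count()}")
    # ... put the per-rank training loop here ...


# Submit the decorated function to serverless GPU compute.
train.distributed()
```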

Serverless GPU API: H100 starter

The following notebook provides a basic example of using the Serverless GPU Python API to launch distributed training across multiple H100 GPUs.

Notebook: Open notebook in new tab
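
The launch pattern is the same as in the A10 sketch above, under the same assumptions about the `serverless_gpu` API; only the decorator arguments change to request H100 GPUs, for example:

```python
# Same assumed serverless_gpu API as the A10 sketch above; only the GPU type
# (and typically the GPU count) changes.
from serverless_gpu import distributed


@distributed(gpus=8, gpu_type="H100", remote=True)
def train():
    ...  # per-rank training loop


train.distributed()
```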

Distributed training using MLflow 3.0

This notebook introduces best practices for using MLflow on Databricks for deep learning workloads on serverless GPU compute. It uses the Serverless GPU API to launch distributed training of a simple classification model on a remote A10 GPU and tracks the training as an MLflow run.

Notebook: Open notebook in new tab
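
A rough sketch of the pattern follows. The MLflow calls are standard tracking APIs; the `serverless_gpu` decorator follows the same assumed API as the starter sketches above, and the experiment path is a hypothetical placeholder.

```python
# Minimal sketch (not the notebook's code): track a remotely launched training
# job as an MLflow run. The serverless_gpu decorator follows the same assumed
# API as the starter sketches; the experiment path is a placeholder.
import mlflow
from serverless_gpu import distributed


@distributed(gpus=1, gpu_type="A10", remote=True)
def train():
    mlflow.set_experiment("/Users/<you>/serverless-gpu-mlflow-demo")  # placeholder path
    with mlflow.start_run():
        mlflow.log_param("lr", 1e-3)
        for epoch in range(3):
            # ... one training epoch of the classification model ...
            mlflow.log_metric("loss", 1.0 / (epoch + 1), step=epoch)


train.distributed()
```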

Distributed training using PyTorch's Distributed Data Parallel (DDP)

The following notebook demonstrates distributed training of a simple multilayer perceptron (MLP) neural network using PyTorch's Distributed Data Parallel (DDP) module on Databricks with serverless GPU compute.

Notebook: Open notebook in new tab
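
Independent of how the worker processes are launched, the core DDP pattern the notebook follows looks roughly like the sketch below: initialize a process group, wrap the model in DistributedDataParallel, and shard the data with DistributedSampler. The sketch uses a toy dataset and assumes the launcher (here, the Serverless GPU API) has already set the standard torch.distributed environment variables (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT).

```python
# Minimal DDP sketch (not the notebook's code): train a small MLP with
# DistributedDataParallel. Assumes RANK/WORLD_SIZE/LOCAL_RANK/MASTER_ADDR/
# MASTER_PORT are already set by the launcher.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Toy dataset; the sampler gives each rank a disjoint shard.
    x, y = torch.randn(1024, 32), torch.randint(0, 2, (1024,))
    dataset = TensorDataset(x, y)
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    # Simple MLP classifier wrapped in DDP.
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2)).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for xb, yb in loader:
            xb, yb = xb.cuda(local_rank), yb.cuda(local_rank)
            opt.zero_grad()
            loss_fn(model(xb), yb).backward()  # DDP averages gradients across ranks
            opt.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```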

Distributed training using PyTorch's Fully Sharded Data Parallel (FSDP)

The following notebook demonstrates distributed training of a Transformer model with 10 million parameters using PyTorch's Fully Sharded Data Parallel (FSDP) module on Databricks with serverless GPU compute.

Notebook: Open notebook in new tab
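
At its core, the FSDP pattern resembles DDP but wraps the model so that parameters, gradients, and optimizer state are sharded across ranks rather than replicated. The sketch below is a minimal illustration only, using a small Transformer encoder and a dummy objective in place of the notebook's training setup, and it assumes the launcher has set the torch.distributed environment variables.

```python
# Minimal FSDP sketch (not the notebook's code): shard a small Transformer
# encoder across ranks. Assumes the torch.distributed environment variables
# are set by the launcher and one GPU is visible per rank.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # A small Transformer encoder stack; FSDP shards its parameters,
    # gradients, and optimizer state across the participating GPUs.
    encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
    model = nn.TransformerEncoder(encoder_layer, num_layers=6).to(device)
    model = FSDP(model, device_id=local_rank)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for step in range(10):
        batch = torch.randn(8, 128, 256, device=device)  # (batch, seq, d_model)
        loss = model(batch).pow(2).mean()  # dummy objective for illustration
        opt.zero_grad()
        loss.backward()
        opt.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```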

Distributed training using Ray

This notebook demonstrates distributed training of a PyTorch ResNet model on the FashionMNIST dataset using Ray Train and Ray Data on Databricks Serverless GPU clusters. It covers setting up Unity Catalog storage, configuring Ray for multi-node GPU training, logging and registering models with MLflow, and evaluating model performance.

Notebook: Open notebook in new tab
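
Structurally, the notebook follows the standard Ray Train pattern: a per-worker training loop handed to a TorchTrainer with a ScalingConfig. The sketch below shows only that skeleton, with a toy model and random data; the notebook's ResNet/FashionMNIST training, Ray Data pipeline, Unity Catalog storage setup, and MLflow logging are omitted.

```python
# Minimal Ray Train sketch (not the notebook's code): a per-worker PyTorch
# training loop scaled across GPU workers. Toy model and random data only.
import torch
import torch.nn as nn

import ray.train
import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    device = ray.train.torch.get_device()
    # prepare_model wraps the model for DDP and moves it to this worker's GPU.
    model = ray.train.torch.prepare_model(nn.Linear(32, 2))
    opt = torch.optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(3):
        x = torch.randn(64, 32, device=device)
        y = torch.randint(0, 2, (64,), device=device)
        loss = loss_fn(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        ray.train.report({"epoch": epoch, "loss": loss.item()})


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
print(result.metrics)
```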

Distributed supervised fine-tuning using TRL

This notebook demonstrates how to use the Serverless GPU Python API to run supervised fine-tuning (SFT) using the TRL library with DeepSpeed ZeRO Stage 3 optimization on a single A10 GPU node. This approach can be extended to multi-node setups.

Notebook: Open notebook in new tab
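
The heart of the fine-tuning step is TRL's SFTTrainer with a DeepSpeed ZeRO Stage 3 configuration passed through the training arguments. The sketch below shows only that piece; the model, dataset, and `ds_zero3.json` path are illustrative placeholders, TRL argument names shift between versions, and the Serverless GPU launch wrapper from the starter sketches is omitted, so defer to the notebook's pinned versions and setup.

```python
# Minimal TRL SFT sketch (not the notebook's code): supervised fine-tuning with
# SFTTrainer and a DeepSpeed ZeRO Stage 3 config. The model name, dataset, and
# "ds_zero3.json" path are placeholders; check the notebook for exact settings.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder instruction-tuning dataset.
train_dataset = load_dataset("trl-lib/Capybara", split="train")

args = SFTConfig(
    output_dir="/tmp/sft-demo",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    deepspeed="ds_zero3.json",  # ZeRO Stage 3: shard params, grads, and optimizer state
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder small model
    args=args,
    train_dataset=train_dataset,
)
trainer.train()
```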