Distributed training with TorchDistributor

This article describes how to perform distributed training on PyTorch ML models using TorchDistributor.

TorchDistributor is an open-source module in PySpark that helps you run distributed PyTorch training on Spark clusters by launching PyTorch training jobs as Spark jobs. Under the hood, it initializes the environment and the communication channels between the workers and uses the CLI command torch.distributed.run to run distributed training across the worker nodes.
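For orientation, the launcher behind torch.distributed.run tells each worker its place in the job through environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE. The sketch below reads them with single-process defaults so the same code also runs outside a distributed job. The helper name read_worker_env and the default values are illustrative assumptions, not part of TorchDistributor.

```python
import os

def read_worker_env(environ=None):
    """Return the rank information that torchrun-style launchers export."""
    # read_worker_env is a hypothetical helper; pass a dict to simulate a worker.
    env = os.environ if environ is None else environ
    return {
        "rank": int(env.get("RANK", "0")),              # global index of this process
        "local_rank": int(env.get("LOCAL_RANK", "0")),  # index of this process on its node
        "world_size": int(env.get("WORLD_SIZE", "1")),  # total number of processes
    }

# Simulate the second process on a node in a four-process job.
info = read_worker_env({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
print(info)
```

Inside a training function you would call the helper without arguments so that it reads the real environment.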

The TorchDistributor API supports the methods shown in the following table.

Method and signature, followed by its description:

__init__(self, num_processes, local_mode, use_gpu)

Creates an instance of TorchDistributor.

run(self, main, *args)

Runs distributed training by invoking main(*args) if main is a function, or by running the CLI command torchrun main *args if main is a file path.
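The two behaviors of run can be pictured with the following sketch. It is not the PySpark implementation: plan_launch is a hypothetical helper that only reports what would happen for each kind of main argument.

```python
def plan_launch(main, *args):
    """Describe how a run(main, *args) call would be dispatched (illustrative only)."""
    if callable(main):
        # A function is invoked directly with the positional arguments.
        return ("function", main.__name__, args)
    # Anything else is treated as a path to a script and handed to torchrun.
    return ("cli", ["torchrun", str(main)] + [str(a) for a in args])

def train(learning_rate):
    # Stand-in for a real training function.
    return learning_rate

print(plan_launch(train, 1e-3))
print(plan_launch("/path/to/train.py", "--learning_rate=0.001"))
```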


Requirements

  • Spark 3.4

  • Databricks Runtime 13.0 ML or above
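A quick way to confirm the Spark version before launching a job is to compare the version string against the minimum. The sketch below is a small, assumed helper (meets_minimum is not part of PySpark). On a cluster you could pass it pyspark.__version__.

```python
import re

def meets_minimum(version, minimum):
    """Return True if a dotted version string is at least the minimum tuple."""
    # Compare only as many numeric components as the minimum specifies.
    numbers = tuple(int(n) for n in re.findall(r"\d+", version)[: len(minimum)])
    return numbers >= tuple(minimum)

# On a cluster: import pyspark; meets_minimum(pyspark.__version__, (3, 4))
print(meets_minimum("3.4.0", (3, 4)))
print(meets_minimum("3.3.2", (3, 4)))
```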

Development workflow for notebooks

If model creation and training happen entirely in a notebook, either on your local machine or in a Databricks notebook, only minor changes are needed to prepare your code for distributed training.

  1. Prepare single node code: Prepare and test the single node code with PyTorch, PyTorch Lightning, or other frameworks built on PyTorch or PyTorch Lightning, such as the Hugging Face Trainer API.

  2. Prepare code for standard distributed training: Convert your single-process training to distributed training, and wrap all of the distributed code in one training function that you can pass to TorchDistributor.

  3. Move imports within training function: Add the necessary imports, such as import torch, within the training function. Doing so avoids common pickling errors. In addition, the device_id that models and data are tied to is determined by:

    device_id = int(os.environ["LOCAL_RANK"])

  4. Launch distributed training: Instantiate the TorchDistributor with the desired parameters and call .run(*args) to launch training.

The following is a training code example:

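from pyspark.ml.torch.distributor import TorchDistributor

def train(learning_rate, use_gpu):
  # Import inside the function so that each worker process resolves the modules itself.
  import os
  import torch
  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel as DDP
  from torch.utils.data import DistributedSampler, DataLoader

  # Use NCCL for GPU training and Gloo for CPU training.
  backend = "nccl" if use_gpu else "gloo"
  dist.init_process_group(backend)
  device = int(os.environ["LOCAL_RANK"]) if use_gpu else "cpu"

  # createModel(), dataset, and run_training_loop() are placeholders for your own
  # model factory, Dataset, and training loop.
  model = createModel().to(device)
  model = DDP(model, device_ids=[device] if use_gpu else None)
  sampler = DistributedSampler(dataset)
  loader = DataLoader(dataset, sampler=sampler)

  output = run_training_loop(model, loader, learning_rate)
  dist.destroy_process_group()
  return output

distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
distributor.run(train, 1e-3, True)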
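To make the role of DistributedSampler in the example more concrete, the following pure-Python sketch mimics how a dataset can be split so that each process trains on its own shard. It is a simplified illustration with shuffling turned off, not the PyTorch implementation, and shard_indices is a hypothetical name.

```python
import math

def shard_indices(dataset_len, world_size, rank):
    """Indices one rank would see when sharding without shuffling (illustrative)."""
    # Assumes the dataset has at least as many samples as there are processes.
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    # Pad by repeating early indices so every rank gets the same number of samples.
    indices += indices[: total - dataset_len]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]

for rank in range(2):
    print(rank, shard_indices(5, 2, rank))
```

Equal shard sizes matter because every process must run the same number of steps for gradient synchronization to stay in lockstep.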

Migrate training from external repositories

If you have an existing distributed training procedure stored in an external repository, you can easily migrate to Databricks by doing the following:

  1. Import the repository: Import the external repository as a Databricks Git folder.

  2. Create a new notebook: Initialize a new Databricks notebook within the repository.

  3. Launch distributed training: In a notebook cell, call TorchDistributor as follows:

from pyspark.ml.torch.distributor import TorchDistributor

train_file = "/path/to/train.py"
args = ["--learning_rate=0.001", "--batch_size=16"]
distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
distributor.run(train_file, *args)
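The entries in args reach the script as ordinary command-line flags. A minimal sketch of how a train.py might parse them is shown below. The file contents and the default values are assumptions for illustration, since only the two flag names appear in the call above.

```python
import argparse

def parse_args(argv=None):
    """Parse the flags that the launcher forwards to a hypothetical train.py."""
    parser = argparse.ArgumentParser(description="Hypothetical train.py entry point")
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=32)
    # With argv=None, argparse reads sys.argv, which is what a real script would do.
    return parser.parse_args(argv)

# Same flags as the args list passed to distributor.run above.
options = parse_args(["--learning_rate=0.001", "--batch_size=16"])
print(options.learning_rate, options.batch_size)
```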


Troubleshooting

A common error in the notebook workflow is that objects cannot be found or pickled when running distributed training. This can happen when the library import statements are not distributed to other executors.

To avoid this issue, include all import statements (for example, import torch) both at the top of the training function that is called with TorchDistributor(...).run(<func>) and inside any other user-defined functions called in the training method.
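The pattern looks like the following sketch. It uses the standard-library math module in place of torch so that it runs anywhere, and the function names are hypothetical.

```python
def compute_loss(values):
    # Helper called from the training function: it repeats its own imports.
    import math
    return math.sqrt(sum(v * v for v in values) / len(values))

def train(values):
    # Imports sit at the top of the function passed to TorchDistributor(...).run(...).
    import math
    loss = compute_loss(values)
    # Truncate to three decimal places for display.
    return math.floor(loss * 1000) / 1000

print(train([3.0, 4.0]))
```

Because neither function relies on names imported at the notebook's top level, both can be serialized and executed in a fresh worker process.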

CUDA failure: peer access is not supported between these two devices

This error can occur on the G5 suite of GPUs on AWS. To resolve it, add the following snippet to your training code:

import os
os.environ["NCCL_P2P_DISABLE"] = "1"

Example notebooks

The following notebook examples demonstrate how to perform distributed training with PyTorch.

End-to-end distributed training on Databricks notebook

Distributed fine-tuning a Hugging Face model notebook

Distributed training on a PyTorch File notebook

Distributed training using PyTorch Lightning notebook

Distributed data loading using Petastorm notebook
