
Distributed deep learning training using PyTorch with HorovodRunner for MNIST

This notebook illustrates the use of HorovodRunner for distributed training using PyTorch. It first shows how to train a model on a single node, and then shows how to adapt the code using HorovodRunner for distributed training. The notebook runs on CPU and GPU clusters.


Requirements: Databricks Runtime 7.0 ML or above.
HorovodRunner is designed to improve model training performance on clusters with multiple workers, but multiple workers are not required to run this notebook.

Set up checkpoint location

The next cell defines the directory for saved checkpoint models. Databricks recommends saving training data under dbfs:/ml, which maps to file:/dbfs/ml on driver and worker nodes.

PYTORCH_DIR = '/dbfs/ml/horovod_pytorch'
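Because dbfs:/ml is exposed to Python as the local path /dbfs/ml, ordinary file APIs such as os and torch.save can work directly on PYTORCH_DIR. The sketch below shows that mapping as plain string manipulation. It uses a hypothetical helper, dbfs_to_local, which is not part of the original notebook and assumes paths of the form dbfs:/...

```python
def dbfs_to_local(dbfs_path: str) -> str:
    """Hypothetical helper: rewrite a dbfs:/ URI to its /dbfs local mount path."""
    prefix = "dbfs:/"
    if not dbfs_path.startswith(prefix):
        raise ValueError(f"expected a path starting with {prefix!r}: {dbfs_path}")
    # dbfs:/ml/horovod_pytorch -> /dbfs/ml/horovod_pytorch
    return "/dbfs/" + dbfs_path[len(prefix):]

local_dir = dbfs_to_local("dbfs:/ml/horovod_pytorch")
print(local_dir)  # /dbfs/ml/horovod_pytorch
```

The result matches the PYTORCH_DIR value used in this notebook.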

Prepare single-node code

First, create single-node PyTorch code. This is modified from the Horovod PyTorch MNIST Example.

Define a simple convolutional network

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28 -> 10x24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10x12x12 -> 20x8x8
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                   # 20 channels * 4 * 4 = 320
        self.fc2 = nn.Linear(50, 10)                    # one score per digit class

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Dropout is active only in training mode (model.train())
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Log-probabilities over the 10 classes, consumed by F.nll_loss
        return F.log_softmax(x, dim=1)
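The 320 in nn.Linear(320, 50) and x.view(-1, 320) comes from tracking the spatial size of a 28x28 MNIST image through the two convolution and pooling stages. The sketch below does that arithmetic in plain Python. It uses the standard output-size formula for unpadded, undilated windows. The name conv_out is an illustrative helper introduced here, not a PyTorch function.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Illustrative helper: output size of a conv or pooling window (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                           # MNIST images are 28x28
size = conv_out(size, 5)            # conv1, kernel_size=5 -> 24
size = conv_out(size, 2, stride=2)  # max_pool2d(..., 2) -> 12
size = conv_out(size, 5)            # conv2, kernel_size=5 -> 8
size = conv_out(size, 2, stride=2)  # max_pool2d(..., 2) -> 4
flattened = 20 * size * size        # conv2 has 20 output channels
print(size, flattened)              # 4 320
```

Dropout2d does not change the tensor shape, so it does not appear in the calculation.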

Configure single-node training

# Specify training parameters
batch_size = 100
num_epochs = 3
momentum = 0.5
log_interval = 100

def train_one_epoch(model, device, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        # Move the batch to the same device as the model
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader) * len(data),
                100. * batch_idx / len(data_loader), loss.item()))
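The 60,000 MNIST training images are split into batches of 100, so each epoch has 600 batches. With log_interval = 100, train_one_epoch therefore prints six progress lines per epoch. The plain-Python sketch below reproduces the sample counts and percentages in those lines without running a model. It assumes the dataset size divides evenly by the batch size, as it does here.

```python
num_samples = 60000  # size of the MNIST training split
batch_size = 100
log_interval = 100

num_batches = num_samples // batch_size  # plays the role of len(data_loader)
logged = []
for batch_idx in range(num_batches):
    if batch_idx % log_interval == 0:
        seen = batch_idx * batch_size  # batch_idx * len(data) in train_one_epoch
        pct = '{:.0f}%'.format(100. * batch_idx / num_batches)
        logged.append((seen, pct))
print(logged)
# [(0, '0%'), (10000, '17%'), (20000, '33%'), (30000, '50%'), (40000, '67%'), (50000, '83%')]
```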

Create methods for saving and loading model checkpoints

import os
from time import time

import torch

def save_checkpoint(log_dir, model, optimizer, epoch):
  filepath = log_dir + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
  }
  torch.save(state, filepath)

def load_checkpoint(log_dir, epoch=num_epochs):
  filepath = log_dir + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  return torch.load(filepath)

def create_log_dir():
  # Use a timestamped subdirectory so each run keeps its own checkpoints
  log_dir = os.path.join(PYTORCH_DIR, str(time()), 'MNISTDemo')
  os.makedirs(log_dir)
  return log_dir
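load_checkpoint defaults to epoch=num_epochs, so it assumes that training ran to completion. If a run stops early, you may want the most recent checkpoint instead. The sketch below uses a hypothetical helper, latest_checkpoint_epoch, that parses the checkpoint-{epoch}.pth.tar names produced by save_checkpoint. It is not part of the original notebook, and it uses empty placeholder files rather than real checkpoints.

```python
import os
import re
import tempfile

# Matches the file names written by save_checkpoint, e.g. checkpoint-3.pth.tar
CHECKPOINT_RE = re.compile(r'^checkpoint-(\d+)\.pth\.tar$')

def latest_checkpoint_epoch(log_dir):
    """Hypothetical helper: return the highest saved epoch in log_dir, or None."""
    epochs = []
    for name in os.listdir(log_dir):
        match = CHECKPOINT_RE.match(name)
        if match:
            epochs.append(int(match.group(1)))  # compare numerically, not as strings
    return max(epochs) if epochs else None

# Usage with empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (1, 2, 10):
        open(os.path.join(tmp, 'checkpoint-{}.pth.tar'.format(epoch)), 'w').close()
    latest = latest_checkpoint_epoch(tmp)
print(latest)  # 10
```

The helper compares epochs numerically, so checkpoint-10 ranks above checkpoint-2, which a lexicographic sort of file names would get wrong. You could then pass the result to load_checkpoint(log_dir, epoch=latest).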

Run single-node training with PyTorch

import torch.optim as optim
from torchvision import datasets, transforms
from time import time
import os

single_node_log_dir = create_log_dir()
print("Log directory:", single_node_log_dir)

def train(learning_rate):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  train_dataset = datasets.MNIST(
    'data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

  for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, device, data_loader, optimizer, epoch)
    save_checkpoint(single_node_log_dir, model, optimizer, epoch)

def test(log_dir):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  loaded_model = Net().to(device)

  # Restore the trained weights and switch off dropout for evaluation
  checkpoint = load_checkpoint(log_dir)
  loaded_model.load_state_dict(checkpoint['model'])
  loaded_model.eval()

  test_dataset = datasets.MNIST(
    'data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(test_dataset)

  test_loss = 0
  with torch.no_grad():
    for data, target in data_loader:
      data, target = data.to(device), target.to(device)
      output = loaded_model(data)
      # DataLoader defaults to batch_size=1, so each term is one sample's loss
      test_loss += F.nll_loss(output, target)

  test_loss /= len(data_loader.dataset)
  print("Average test loss: {}".format(test_loss.item()))
Log directory: /dbfs/ml/horovod_pytorch/1610480453.3483307/MNISTDemo
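Both datasets apply transforms.ToTensor() followed by transforms.Normalize((0.1307,), (0.3081,)). ToTensor scales pixel values to [0, 1]. Normalize then subtracts the mean and divides by the standard deviation. The values 0.1307 and 0.3081 are the commonly used mean and standard deviation of the MNIST training pixels. The plain-Python sketch below applies the same formula to the darkest and brightest possible pixels to show the resulting range. The normalize function is an illustrative stand-in, not the torchvision implementation.

```python
MNIST_MEAN = 0.1307  # values passed to transforms.Normalize above
MNIST_STD = 0.3081

def normalize(pixel):
    """Illustrative stand-in for the per-channel formula (x - mean) / std."""
    return (pixel - MNIST_MEAN) / MNIST_STD

# ToTensor scales raw 0-255 pixels to [0.0, 1.0] before Normalize runs
black = normalize(0 / 255)
white = normalize(255 / 255)
print(round(black, 4), round(white, 4))
```

After normalization, inputs fall roughly between -0.42 and 2.82 instead of between 0 and 1.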

Run the train function you just created to train a model on the driver node.

train(learning_rate=0.001)
Downloading to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
/opt/conda/conda-bld/pytorch_1587428190859/work/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.
Done!
/local_disk0/tmp/1610480446960-0/ UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
Train Epoch: 1 [0/60000 (0%)] Loss: 2.308137
Train Epoch: 1 [10000/60000 (17%)] Loss: 2.281827
Train Epoch: 1 [20000/60000 (33%)] Loss: 2.273307
Train Epoch: 1 [30000/60000 (50%)] Loss: 2.255614
Train Epoch: 1 [40000/60000 (67%)] Loss: 2.248265
Train Epoch: 1 [50000/60000 (83%)] Loss: 2.225026
Train Epoch: 2 [0/60000 (0%)] Loss: 2.201678
Train Epoch: 2 [10000/60000 (17%)] Loss: 2.169746
Train Epoch: 2 [20000/60000 (33%)] Loss: 2.087094
Train Epoch: 2 [30000/60000 (50%)] Loss: 1.892635
Train Epoch: 2 [40000/60000 (67%)] Loss: 1.856333
Train Epoch: 2 [50000/60000 (83%)] Loss: 1.815048
Train Epoch: 3 [0/60000 (0%)] Loss: 1.505360
Train Epoch: 3 [10000/60000 (17%)] Loss: 1.550015
Train Epoch: 3 [20000/60000 (33%)] Loss: 1.392060
Train Epoch: 3 [30000/60000 (50%)] Loss: 1.282646
Train Epoch: 3 [40000/60000 (67%)] Loss: 1.298241
Train Epoch: 3 [50000/60000 (83%)] Loss: 1.197927
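The first logged loss, 2.308137, is close to ln 10 ≈ 2.3026. That value is the negative log-likelihood of a model that assigns equal probability to all ten digits. Net.forward ends with F.log_softmax and training uses F.nll_loss. Together they compute -log p(target) per sample, and F.nll_loss then averages over the batch by default. The plain-Python sketch below reimplements both steps for a single sample to confirm the uniform-guess value. The function names are illustrative, not the PyTorch implementations.

```python
import math

def log_softmax(logits):
    """Illustrative, numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll_loss(log_probs, target):
    """Illustrative per-sample negative log-likelihood of the target class."""
    return -log_probs[target]

# An untrained model that scores all 10 digits equally
uniform_logits = [0.0] * 10
loss = nll_loss(log_softmax(uniform_logits), target=3)
print(loss, math.log(10))
```

Read this way, the drop to about 1.2 by epoch 3 shows the model moving away from uniform guessing.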