mnist-pytorch(Python)

Distributed deep learning training using PyTorch with HorovodRunner for MNIST

This notebook illustrates the use of HorovodRunner for distributed training using PyTorch. It first shows how to train a model on a single node, and then shows how to adapt the code using HorovodRunner for distributed training. The notebook runs on CPU and GPU clusters.

Requirements

Databricks Runtime 7.0 ML or above.
HorovodRunner is designed to improve model training performance on clusters with multiple workers, but multiple workers are not required to run this notebook.

Set up checkpoint location

The next cell defines the base directory for saving model checkpoints; the timestamped subdirectories themselves are created later by create_log_dir. Databricks recommends saving checkpoints and training data under dbfs:/ml, which maps to file:/dbfs/ml on driver and worker nodes.

PYTORCH_DIR = '/dbfs/ml/horovod_pytorch'
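
If you prefer to create the base path up front, the following minimal sketch does so; it assumes DBFS is mounted at /dbfs, as on standard Databricks clusters, and is safe to re-run.

import os

# Create the base checkpoint directory under dbfs:/ml if it does not exist yet.
os.makedirs(PYTORCH_DIR, exist_ok=True)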

Prepare single-node code

First, create single-node PyTorch code. This is modified from the Horovod PyTorch MNIST Example.

Define a simple convolutional network

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
 
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
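
As a quick, optional sanity check (not part of the original notebook), you can push a dummy MNIST-shaped batch through the network. The two conv-plus-pooling stages reduce each 28x28 image to 20 feature maps of size 4x4, which is where the 320-unit input to fc1 comes from.

# Optional smoke test: a batch of 8 dummy 1x28x28 images should produce an
# 8x10 tensor of log-probabilities.
dummy_batch = torch.randn(8, 1, 28, 28)
print(Net()(dummy_batch).shape)  # expected: torch.Size([8, 10])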

Configure single-node training

# Specify training parameters
batch_size = 100
num_epochs = 3
momentum = 0.5
log_interval = 100

def train_one_epoch(model, device, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader) * len(data),
                100. * batch_idx / len(data_loader), loss.item()))

Create methods for saving and loading model checkpoints

def save_checkpoint(log_dir, model, optimizer, epoch):
  filepath = log_dir + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
  }
  torch.save(state, filepath)
  
def load_checkpoint(log_dir, epoch=num_epochs):
  filepath = log_dir + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  return torch.load(filepath)
 
def create_log_dir():
  log_dir = os.path.join(PYTORCH_DIR, str(time()), 'MNISTDemo')
  os.makedirs(log_dir)
  return log_dir
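
A checkpoint saved this way can be restored into a fresh model and optimizer to resume training from a given epoch. The helper below is a hypothetical sketch, not part of the original notebook; like create_log_dir above, it relies on imports (optim) that the next cell provides.

# Hypothetical helper: rebuild a model and optimizer from a saved checkpoint.
def resume_from_checkpoint(log_dir, learning_rate, epoch):
  checkpoint = load_checkpoint(log_dir, epoch=epoch)
  model = Net()
  model.load_state_dict(checkpoint['model'])
  optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
  optimizer.load_state_dict(checkpoint['optimizer'])
  return model, optimizer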

Run single-node training with PyTorch

import torch.optim as optim
from torchvision import datasets, transforms
from time import time
import os
 
single_node_log_dir = create_log_dir()
print("Log directory:", single_node_log_dir)
 
def train(learning_rate):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
  train_dataset = datasets.MNIST(
    'data', 
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
  model = Net().to(device)
 
  optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
 
  for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, device, data_loader, optimizer, epoch)
    save_checkpoint(single_node_log_dir, model, optimizer, epoch)
 
    
def test(log_dir):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  loaded_model = Net().to(device)
  
  checkpoint = load_checkpoint(log_dir)
  loaded_model.load_state_dict(checkpoint['model'])
  loaded_model.eval()
 
  test_dataset = datasets.MNIST(
    'data', 
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(test_dataset)
 
  test_loss = 0
  for data, target in data_loader:
      data, target = data.to(device), target.to(device)
      output = loaded_model(data)
      test_loss += F.nll_loss(output, target)
  
  test_loss /= len(data_loader.dataset)
  print("Average test loss: {}".format(test_loss.item()))
Log directory: /dbfs/ml/horovod_pytorch/1610480453.3483307/MNISTDemo

Run the train function you just created to train a model on the driver node.

train(learning_rate = 0.001)
Train Epoch: 1 [0/60000 (0%)]  Loss: 2.308137
Train Epoch: 1 [10000/60000 (17%)]  Loss: 2.281827
Train Epoch: 1 [20000/60000 (33%)]  Loss: 2.273307
Train Epoch: 1 [30000/60000 (50%)]  Loss: 2.255614
Train Epoch: 1 [40000/60000 (67%)]  Loss: 2.248265
Train Epoch: 1 [50000/60000 (83%)]  Loss: 2.225026
Train Epoch: 2 [0/60000 (0%)]  Loss: 2.201678
Train Epoch: 2 [10000/60000 (17%)]  Loss: 2.169746
Train Epoch: 2 [20000/60000 (33%)]  Loss: 2.087094
Train Epoch: 2 [30000/60000 (50%)]  Loss: 1.892635
Train Epoch: 2 [40000/60000 (67%)]  Loss: 1.856333
Train Epoch: 2 [50000/60000 (83%)]  Loss: 1.815048
Train Epoch: 3 [0/60000 (0%)]  Loss: 1.505360
Train Epoch: 3 [10000/60000 (17%)]  Loss: 1.550015
Train Epoch: 3 [20000/60000 (33%)]  Loss: 1.392060
Train Epoch: 3 [30000/60000 (50%)]  Loss: 1.282646
Train Epoch: 3 [40000/60000 (67%)]  Loss: 1.298241
Train Epoch: 3 [50000/60000 (83%)]  Loss: 1.197927

Load and use the model

test(single_node_log_dir)
Average test loss: 0.6969888210296631
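
Beyond computing the test loss, you can use the restored model for inference on individual images. The sketch below is not part of the original notebook; it loads the final checkpoint saved above and predicts the digit for a single test image.

# Hypothetical inference example: predict the label of one test image.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
model.load_state_dict(load_checkpoint(single_node_log_dir)['model'])
model.eval()

test_dataset = datasets.MNIST(
  'data',
  train=False,
  download=True,
  transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

image, label = test_dataset[0]
with torch.no_grad():
  prediction = model(image.unsqueeze(0).to(device)).argmax(dim=1).item()
print("Predicted:", prediction, "Actual:", label)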

Migrate to HorovodRunner

HorovodRunner takes a Python method that contains deep learning training code with Horovod hooks. HorovodRunner pickles the method on the driver and distributes it to Spark workers. A Horovod MPI job is embedded as a Spark job using barrier execution mode.

import horovod.torch as hvd
from sparkdl import HorovodRunner
hvd_log_dir = create_log_dir()
print("Log directory:", hvd_log_dir)
 
def train_hvd(learning_rate):
  
  # Initialize Horovod
  hvd.init()  
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  
  if device.type == 'cuda':
    # Pin GPU to local rank
    torch.cuda.set_device(hvd.local_rank())
 
  train_dataset = datasets.MNIST(
    # Use different root directory for each worker to avoid conflicts
    root='data-%d'% hvd.rank(),  
    train=True, 
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  )
 
  from torch.utils.data.distributed import DistributedSampler
  
  # Configure the sampler so that each worker gets a distinct sample of the input dataset
  train_sampler = DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
  # Use train_sampler to load a different sample of data on each worker
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
 
  model = Net().to(device)
  
  # The effective batch size in synchronous distributed training is scaled by the number of workers
  # Increase learning_rate to compensate for the increased batch size
  optimizer = optim.SGD(model.parameters(), lr=learning_rate * hvd.size(), momentum=momentum)
 
  # Wrap the local optimizer with hvd.DistributedOptimizer so that Horovod handles the distributed optimization
  optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
  
  # Broadcast initial parameters so all workers start with the same parameters
  hvd.broadcast_parameters(model.state_dict(), root_rank=0)
 
  for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, device, train_loader, optimizer, epoch)
    # Save checkpoints only on worker 0 to prevent conflicts between workers
    if hvd.rank() == 0:
      save_checkpoint(hvd_log_dir, model, optimizer, epoch)
Log directory: /dbfs/ml/horovod_pytorch/1610480588.0704386/MNISTDemo
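
One optional refinement used by the upstream Horovod PyTorch MNIST example (and omitted in the cell above) is to broadcast the optimizer state as well, so that momentum buffers also start identical on every worker. The excerpt below sketches where the extra call would go inside train_hvd; it is a suggestion, not part of the original notebook.

  # Inside train_hvd, right after broadcasting the model parameters:
  hvd.broadcast_parameters(model.state_dict(), root_rank=0)
  hvd.broadcast_optimizer_state(optimizer, root_rank=0)  # also sync optimizer state (e.g. momentum)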

Now that you have defined a training function with Horovod, you can use HorovodRunner to distribute the work of training the model.

The HorovodRunner parameter np sets the number of processes. This example uses a cluster with two workers, each with a single GPU, so set np=2. (If you use np=-1, HorovodRunner trains using a single process on the driver node.)

hr = HorovodRunner(np=2) 
hr.run(train_hvd, learning_rate = 0.001)
HorovodRunner will stream all training logs to notebook cell output. If there are too many logs, you can adjust the log level in your train method. Or you can set driver_log_verbosity to 'log_callback_only' and use a HorovodRunner log callback on the first worker to get concise progress updates. The global names read or written to by the pickled function are {'range', 'hvd', 'datasets', 'batch_size', 'num_epochs', 'Net', 'momentum', 'save_checkpoint', 'torch', 'transforms', 'optim', 'hvd_log_dir', 'train_one_epoch'}. The pickled object size is 4532 bytes.
### How to enable Horovod Timeline? ###
HorovodRunner has the ability to record the timeline of its activity with Horovod Timeline. To record a Horovod Timeline, set the `HOROVOD_TIMELINE` environment variable to the location of the timeline file to be created. You can then open the timeline file using the chrome://tracing facility of the Chrome browser.
Start training.
[1,0]<stdout>:Train Epoch: 1 [0/30000 (0%)]  Loss: 2.322573
[1,1]<stdout>:Train Epoch: 1 [0/30000 (0%)]  Loss: 2.352982
[1,0]<stdout>:Train Epoch: 1 [10000/30000 (33%)]  Loss: 2.289413
[1,1]<stdout>:Train Epoch: 1 [10000/30000 (33%)]  Loss: 2.301658
[1,0]<stdout>:Train Epoch: 1 [20000/30000 (67%)]  Loss: 2.240432
[1,1]<stdout>:Train Epoch: 1 [20000/30000 (67%)]  Loss: 2.248593
[1,1]<stdout>:Train Epoch: 2 [0/30000 (0%)]  Loss: 2.158686
[1,0]<stdout>:Train Epoch: 2 [0/30000 (0%)]  Loss: 2.151930
[1,1]<stdout>:Train Epoch: 2 [10000/30000 (33%)]  Loss: 2.017957
[1,0]<stdout>:Train Epoch: 2 [10000/30000 (33%)]  Loss: 2.074805
[1,0]<stdout>:Train Epoch: 2 [20000/30000 (67%)]  Loss: 1.832516
[1,1]<stdout>:Train Epoch: 2 [20000/30000 (67%)]  Loss: 1.880822
[1,0]<stdout>:Train Epoch: 3 [0/30000 (0%)]  Loss: 1.660405
[1,1]<stdout>:Train Epoch: 3 [0/30000 (0%)]  Loss: 1.764962
[1,0]<stdout>:Train Epoch: 3 [10000/30000 (33%)]  Loss: 1.392239
[1,1]<stdout>:Train Epoch: 3 [10000/30000 (33%)]  Loss: 1.504940
[1,0]<stdout>:Train Epoch: 3 [20000/30000 (67%)]  Loss: 1.504644
[1,1]<stdout>:Train Epoch: 3 [20000/30000 (67%)]  Loss: 1.326515
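
The cell output above mentions Horovod Timeline. The sketch below shows one way you might enable it, assuming the HOROVOD_TIMELINE environment variable set on the driver is visible to the Horovod job on your cluster; the DBFS path is only an illustrative choice, not part of the original notebook.

import os

# Illustrative path on shared storage; set this before constructing HorovodRunner
# and calling hr.run(). The resulting JSON file can be opened with chrome://tracing.
os.environ['HOROVOD_TIMELINE'] = '/dbfs/ml/horovod_timeline.json'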
test(hvd_log_dir)
Average test loss: 0.6904363036155701

Under the hood, HorovodRunner pickles the training method on the driver and distributes it to Spark workers, where the Horovod MPI job is embedded as a Spark job using barrier execution mode. The first executor collects the IP addresses of all task executors using BarrierTaskContext and triggers a Horovod job using mpirun. Each Python MPI process loads the pickled user program, deserializes it, and runs it.

For more information, see the HorovodRunner API documentation.