
Train PyTorch ResNet18 model using Ray on serverless GPU compute

This notebook demonstrates distributed training of a PyTorch ResNet18 model on the FashionMNIST dataset using Ray Train and Ray Data on remote multi-node GPU clusters. The training runs on serverless GPU compute, which provides on-demand access to GPU resources without managing infrastructure.

Key concepts:

  • Ray Train: A distributed training library that scales PyTorch training across multiple GPUs and nodes
  • Ray Data: A scalable data loading and preprocessing library optimized for ML workloads
  • Serverless GPU compute: Databricks-managed GPU clusters that automatically scale based on workload demands
  • Unity Catalog: Unified governance solution for storing models, artifacts, and datasets

The notebook covers setting up Unity Catalog storage, configuring Ray for multi-node GPU training, logging experiments with MLflow, and registering models for deployment.

Requirements

Serverless GPU compute with A10 accelerator

Connect the notebook to serverless GPU compute:

  1. Click the Connect dropdown at the top.
  2. Select Serverless GPU.
  3. Open the Environment side panel on the right side of the notebook.
  4. Set Accelerator to A10.
  5. Select Apply and click Confirm.

The following cell installs the required Python packages for Ray and MLflow.

Python
%pip install ray==2.49.1
%pip install -U "mlflow>3"
%restart_python

Configure Unity Catalog storage locations

The following cell creates widgets to specify the Unity Catalog path for storing models and artifacts. Update the widget values to match your workspace configuration.

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "sgc-nightly")
dbutils.widgets.text("uc_model_name", "ray_pytorch_mnist")
dbutils.widgets.text("uc_volume", "datasets")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")

After updating the widget values at the top of the notebook, run the following cell to apply the changes.

Python

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")

Set up Ray storage in Unity Catalog Volumes

Ray requires a persistent storage location for checkpoints, logs, and training artifacts. The following cell creates a directory in a Unity Catalog Volume to store Ray training runs.

Python
import os

RAY_STORAGE = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/ray_runs"
if not os.path.exists(RAY_STORAGE):
    print(f"Creating directory: {RAY_STORAGE}")
    os.mkdir(RAY_STORAGE)

Configure distributed training with Ray

The following cell defines a helper function to retrieve MLflow run IDs by tag. This function searches for experiment runs matching a specific tag key-value pair, which is useful for tracking and retrieving training runs.

Python
import mlflow

def get_run_id_by_tag(experiment_name, tag_key, tag_value):
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment:
        runs = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string=f"tags.{tag_key} = '{tag_value}'",
        )
        if runs:
            return runs[0].info.run_id
    return None

Define the distributed training function

The following cell defines a Ray function decorated with @ray_launch that runs distributed model training on remote GPU pools. Key configuration:

  • gpus=8: Requests 8 GPUs for distributed training
  • gpu_type="A10": Specifies A10 GPU accelerators
  • remote=True: Runs on serverless GPU compute instead of the notebook's compute

The function uses Ray's setup_mlflow for experiment tracking and integrates with Unity Catalog for artifact storage. See the Ray MLflow integration documentation for more details.

Python
from serverless_gpu.ray import ray_launch
from serverless_gpu.utils import get_mlflow_experiment_name
import uuid

# Set up the MLflow experiment
username = spark.sql("SELECT current_user()").collect()[0][0]
experiment_name = f"/Users/{username}/Ray_on_AIR_PyTorch"
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
mlflow_tracking_uri = "databricks"  # or your custom tracking URI
tag_id = f"fashion-mnist-ray-{uuid.uuid4()}"

@ray_launch(gpus=8, gpu_type="A10", remote=True)
def my_ray_function():

    # Your Ray code here
    import os
    import tempfile
    import uuid

    import torch
    from torch.nn import CrossEntropyLoss
    from torch.optim import Adam
    from torchvision.models import resnet18
    from torchvision.datasets import FashionMNIST
    from torchvision.transforms import ToTensor, Normalize, Compose

    import ray
    import ray.train
    import ray.train.torch
    import ray.data

    import mlflow
    from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow
    from ray.air.config import RunConfig

    def torch_dataset_to_ray(ds):
        """Converts a PyTorch dataset to a Ray dataset.

        Args:
            ds (torch.utils.data.Dataset): The PyTorch dataset to convert.

        Returns:
            ray.data.Dataset: The converted Ray dataset.
        """
        items = []
        for i in range(len(ds)):
            x, y = ds[i]
            items.append({"image": x, "label": y})
        return ray.data.from_items(items)

    def train_func():
        # Initialize the model and modify the first convolutional layer to
        # accept single-channel images
        model = resnet18(num_classes=10)
        model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        model = ray.train.torch.prepare_model(model)

        # Define the loss function and optimizer
        criterion = CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=0.001)

        # Get the training and validation dataset shards
        train_shard = ray.train.get_dataset_shard("train")
        val_shard = ray.train.get_dataset_shard("validation")

        # Start an MLflow run on rank 0 inside the trainable; artifacts go to the UC Volume
        ctx = ray.train.get_context()
        mlflow_worker = setup_mlflow(
            # config={"lr": 0.001, "batch_size": 128},
            tracking_uri=mlflow_tracking_uri,
            registry_uri="databricks-uc",
            experiment_name=experiment_name,
            # run_name=tag_id,
            # artifact_location=MLFLOW_ARTIFACT_ROOT,
            create_experiment_if_not_exists=True,
            tags={"tag_id": tag_id},
            rank_zero_only=True,
        )

        # Training loop for a fixed number of epochs
        for epoch in range(10):
            model.train()
            train_loss = 0.0
            num_train_batches = 0

            # Iterate over training batches
            for batch in train_shard.iter_torch_batches(batch_size=128):
                images = batch["image"]
                labels = batch["label"]
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                num_train_batches += 1
            train_loss /= max(num_train_batches, 1)

            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            num_val_batches = 0

            # Iterate over validation batches
            with torch.no_grad():
                for batch in val_shard.iter_torch_batches(batch_size=128):
                    images = batch["image"]
                    labels = batch["label"]
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    num_val_batches += 1
            val_loss /= max(num_val_batches, 1)
            val_accuracy = correct / total if total > 0 else 0.0

            # Log metrics and save model checkpoint
            metrics = {
                "train_loss": train_loss,
                "eval_loss": val_loss,
                "eval_accuracy": val_accuracy,
                "epoch": epoch,
            }

            # Rank 0: save checkpoint, log to MLflow, and report Ray checkpoint
            if ctx.get_world_rank() == 0:
                # Save model state to a temp dir
                with tempfile.TemporaryDirectory() as tmp:
                    state_path = os.path.join(tmp, "model.pt")
                    torch.save(
                        model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                        state_path,
                    )
                    # Log to MLflow under the UC Volume-backed experiment
                    mlflow_worker.log_artifact(state_path, artifact_path=f"checkpoints/epoch={epoch}")
                    # Hand the same state to Ray Train for resumability/best-checkpoint selection
                    checkpoint = ray.train.Checkpoint.from_directory(tmp)
                    ray.train.report(metrics, checkpoint=checkpoint)

                # Also stream metrics to MLflow from rank 0
                mlflow_worker.log_metrics(metrics, step=epoch)
            else:
                # Non-zero ranks report metrics only
                ray.train.report(metrics)

    # Prepare datasets
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()), "data")
    full_train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    test_data = FashionMNIST(root=data_dir, train=False, download=True, transform=transform)
    train_size = int(0.8 * len(full_train_data))
    val_size = len(full_train_data) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_train_data, [train_size, val_size])
    ray_train_ds = torch_dataset_to_ray(train_dataset)
    ray_val_ds = torch_dataset_to_ray(val_dataset)
    ray_test_ds = torch_dataset_to_ray(test_data)

    # Scaling config
    scaling_config = ray.train.ScalingConfig(
        num_workers=int(ray.cluster_resources().get("GPU")),
        use_gpu=True,
    )

    # Add MLflowLoggerCallback to RunConfig
    run_config = RunConfig(
        name="ray_fashion_mnist_sgc_example",
        storage_path=RAY_STORAGE,  # persistence path on the UC Volume
        checkpoint_config=ray.train.CheckpointConfig(num_to_keep=10),
        callbacks=[
            MLflowLoggerCallback(
                tracking_uri=mlflow_tracking_uri,
                experiment_name=experiment_name,
                save_artifact=True,
                tags={"tag_id": tag_id},
            )
        ],
    )

    trainer = ray.train.torch.TorchTrainer(
        train_func,
        scaling_config=scaling_config,
        run_config=run_config,
        datasets={
            "train": ray_train_ds,
            "validation": ray_val_ds,
            "test": ray_test_ds,
        },
    )
    result = trainer.fit()

    return result
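
As an aside, the normalization constants in Normalize((0.28604,), (0.32025,)) and the 80/20 train/validation split sizes used above can be derived with plain Python. The following is a minimal sketch; the pixel values here are synthetic stand-ins for the real FashionMNIST tensors, so the printed mean/std are illustrative only.

```python
import random
import statistics

# Synthetic stand-in for flattened FashionMNIST pixel values scaled to [0, 1];
# on the real dataset this procedure yields roughly mean 0.28604 and std 0.32025.
random.seed(0)
pixels = [random.random() for _ in range(10_000)]

mean = statistics.fmean(pixels)
std = statistics.pstdev(pixels)  # population std, like tensor.std(unbiased=False)
print(f"Normalize(({mean:.5f},), ({std:.5f},))")

# The 80/20 split sizes computed for random_split above
n = 60_000  # FashionMNIST training-set size
train_size = int(0.8 * n)
val_size = n - train_size
print(train_size, val_size)  # 48000 12000
```

Computing the statistics once and hard-coding them, as the notebook does, avoids a full pass over the dataset on every training run.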

Run the distributed training job

The following cell executes the training function on remote serverless GPU compute. The .distributed() method triggers the Ray job to run on the requested GPU resources and returns the training result; depending on the launcher configuration, this can be a list of results, which the model-registration cell below accounts for.

Python
result = my_ray_function.distributed()
result

Register the model to MLflow for inference

The following cell retrieves the most recent MLflow run from the experiment to access the trained model checkpoint.

Python
import mlflow

client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)
latest_run = None
if experiment:
    runs = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        order_by=["attributes.start_time DESC"],
        max_results=1,
    )
    if runs:
        latest_run = runs[0]
run_id = latest_run.info.run_id

The following cell loads the trained model from the checkpoint, logs it to MLflow with an input example, and evaluates its performance on a sample of test data.

Python
import os
import tempfile
import uuid

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import mlflow
import numpy as np

# Load the model weights from the checkpoint
result = result[0] if isinstance(result, list) else result
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

# Prepare the test dataset
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
data_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()), "data")
test_data = FashionMNIST(root=data_dir, train=False, download=True, transform=transform)

# Get a small batch for the input example
dataloader = torch.utils.data.DataLoader(test_data, batch_size=5, shuffle=True)
batch = next(iter(dataloader))
features = batch[0].numpy()  # Images/features
targets = batch[1].numpy()  # Labels/targets

# Configure MLflow to use Unity Catalog
mlflow.set_registry_uri("databricks-uc")

# Log the model with an input example
with mlflow.start_run(run_id=run_id) as run:
    logged_model = mlflow.pytorch.log_model(
        model,
        "model",
        input_example=features,
    )

loaded_model = mlflow.pytorch.load_model(logged_model.model_uri)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)
loaded_model.eval()

criterion = CrossEntropyLoss()
num_batches_to_eval = 5

total = 0
correct = 0
test_loss = 0.0
processed_batches = 0

with torch.no_grad():
    for i, (images, labels) in enumerate(dataloader):
        if i >= num_batches_to_eval:
            break
        images = images.to(device)
        labels = labels.to(device)

        outputs = loaded_model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        processed_batches += 1

accuracy = 100.0 * correct / max(total, 1)
avg_loss = test_loss / max(processed_batches, 1)

print("First-5-batch results")
print(f"  Accuracy: {accuracy:.2f}%")
print(f"  Avg loss: {avg_loss:.4f}")
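
The accuracy computed above reduces to a per-row argmax over class scores compared against the labels. A dependency-free sketch of that reduction (the function name and sample values here are illustrative, not part of the notebook):

```python
def top1_accuracy(logits, labels):
    """Fraction of rows whose argmax matches the label, mirroring
    torch.max(outputs, 1) followed by (predicted == labels).sum()."""
    correct = sum(
        1
        for row, label in zip(logits, labels)
        if max(range(len(row)), key=row.__getitem__) == label
    )
    return correct / max(len(labels), 1)

# Three samples, three classes
scores = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05], [0.2, 0.3, 0.5]]
print(top1_accuracy(scores, [1, 0, 2]))  # all correct: 1.0
print(top1_accuracy(scores, [0, 0, 2]))  # two of three correct
```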

The following cell registers the logged model to Unity Catalog, making it available for deployment and governance across the workspace.

Python
# Configure MLflow to use Unity Catalog
mlflow.set_registry_uri("databricks-uc")
mlflow.register_model(model_uri=logged_model.model_uri, name=f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}")

Next steps

Explore additional resources for distributed training and serverless GPU compute in the Databricks documentation.
