Ensinar o modelo ResNet18 PyTorch usando Ray em computeGPU serverless .
Este Notebook demonstra o treinamento distribuído de um modelo PyTorch ResNet18 no dataset FashionMNIST usando Ray ensina e Ray Data em clusters de GPU remotos com múltiplos nós. O treinamento é executado em computeGPUserverless, que fornece acesso sob demanda a recursos de GPU sem a necessidade de gerenciamento de infraestrutura.
conceitos-chave:
- Ray ensinando : Uma biblioteca de treinamento distribuída que escala o treinamento PyTorch em várias GPUs e nós
- Ray Data : Uma biblioteca escalável para carregamento e pré-processamento de dados, otimizada para cargas de trabalho ML
- computeGPU sem servidor : Databricks- clusters de GPU gerenciais que escalam automaticamente com base nas demandas de carga de trabalho.
- Unity Catalog : Soluções unificadas de governança para armazenamento de modelos, artefatos e conjuntos de dados.
O Notebook aborda a configuração do armazenamento Unity Catalog , a configuração do Ray para treinamento com GPUs em vários nós, o registro de experimentos com MLflow e o registro de modelos para implantação.
Requisitos
compute GPU sem servidor com acelerador A10
Conecte o notebook à compute GPU serverless :
- Clique no dropdown "Conectar" na parte superior.
- Selecione GPU sem servidor .
- Abra o painel lateral Ambiente , localizado no lado direito do Notebook.
- Defina o acelerador para A10 .
- Selecione Aplicar e clique em Confirmar .
A célula a seguir instala o pacote Python necessário para Ray e MLflow.
%pip install ray==2.49.1
%pip install -U mlflow>3
%restart_python
Configure os locais de armazenamento Unity Catalog .
A célula a seguir cria widgets para especificar o caminho Unity Catalog para armazenar modelos e artefatos. Atualize os valores do widget para corresponder à configuração do seu workspace .
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "sgc-nightly")
dbutils.widgets.text("uc_model_name", "ray_pytorch_mnist")
dbutils.widgets.text("uc_volume", "datasets")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
Após atualizar os valores do widget na parte superior do Notebook, execute a seguinte célula para aplicar as alterações.
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
Configure o armazenamento Ray nos Volumes Unity Catalog
O Ray requer um local de armazenamento persistente para pontos de verificação, logs e artefatos de treinamento. A célula a seguir cria um diretório em um Volume Unity Catalog para armazenar a execução do treinamento Ray.
import os
RAY_STORAGE = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/ray_runs"
if not os.path.exists(RAY_STORAGE):
print(f"Creating directory: {RAY_STORAGE}")
os.mkdir(RAY_STORAGE)
Configure o treinamento distribuído com o Ray.
A célula a seguir define uma função auxiliar para recuperar IDs de execução do MLflow por tag. Esta função busca a execução do experimento correspondente a uma tag específica por key-valor, o que é útil para acompanhamento e recuperação da execução do treinamento.
import mlflow
def get_run_id_by_tag(experiment_name, tag_key, tag_value):
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)
if experiment:
runs = client.search_runs(
experiment_ids=[experiment.experiment_id],
filter_string=f"tags.{tag_key} = '{tag_value}'"
)
if runs:
return runs[0].info.run_id
return None
Defina a função de treinamento distribuído
A célula a seguir define uma função Ray decorada com @ray_launch que executa treinamento de modelo distribuído em um pool de GPUs remoto. Configuração principal:
gpus=8: Solicita 8 GPUs para treinamento distribuídogpu_type="A10": Especifica aceleradores de GPU A10remote=True: execução em compute GPU serverless em vez da computedo Notebook
A função usa o setup_mlflow do Ray para acompanhamento do experimento e se integra ao Unity Catalog para armazenamento de artefatos. Consulte a documentação de integração do Ray com o MLflow para obter mais detalhes.
from serverless_gpu.ray import ray_launch
from serverless_gpu.utils import get_mlflow_experiment_name
import uuid
# Setup mlflow experiment
username = spark.sql("SELECT session_user()").collect()[0][0]
experiment_name = f"/Users/{username}/Ray_on_AIR_PyTorch"
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
mlflow_tracking_uri = "databricks" # or your custom tracking URI
tag_id = f"fashion-mnist-ray-{uuid.uuid4()}"
@ray_launch(gpus=8, gpu_type="A10", remote=True)
def my_ray_function():
# Your Ray code here
import os
import tempfile
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import ray
import ray.train
import ray.train.torch
import ray.data
import mlflow
from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow
from ray.air.config import RunConfig
def torch_dataset_to_ray(ds):
"""Converts a PyTorch dataset to a Ray dataset.
Args:
ds (torch.utils.data.Dataset): The PyTorch dataset to convert.
Returns:
ray.data.Dataset: The converted Ray dataset.
"""
items = []
for i in range(len(ds)):
x, y = ds[i]
items.append({"image": x, "label": y})
return ray.data.from_items(items)
def train_func():
# Initialize the model and modify the first convolutional layer to accept single-channel images
model = resnet18(num_classes=10)
model.conv1 = torch.nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
model = ray.train.torch.prepare_model(model)
# Define the loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
# Get the training and validation dataset shards
train_shard = ray.train.get_dataset_shard("train")
val_shard = ray.train.get_dataset_shard("validation")
# Start an MLflow run on rank 0 inside the trainable; artifacts go to UC Volume
ctx = ray.train.get_context()
mlflow_worker = setup_mlflow(
# config={"lr": 0.001, "batch_size": 128},
tracking_uri=mlflow_tracking_uri,
registry_uri="databricks-uc",
experiment_name=experiment_name,
# run_name=tag_id,
# artifact_location=MLFLOW_ARTIFACT_ROOT,
create_experiment_if_not_exists=True,
tags={"tag_id": tag_id},
rank_zero_only=True,
)
# Training loop for a fixed number of epochs
for epoch in range(10):
model.train()
train_loss = 0.0
num_train_batches = 0
# Iterate over training batches
for batch in train_shard.iter_torch_batches(batch_size=128):
images = batch["image"]
labels = batch["label"]
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
num_train_batches += 1
train_loss /= max(num_train_batches, 1)
model.eval()
val_loss = 0.0
correct = 0
total = 0
num_val_batches = 0
# Iterate over validation batches
with torch.no_grad():
for batch in val_shard.iter_torch_batches(batch_size=128):
images = batch["image"]
labels = batch["label"]
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
num_val_batches += 1
val_loss /= max(num_val_batches, 1)
val_accuracy = correct / total if total > 0 else 0.0
# Log metrics and save model checkpoint
metrics = {
"train_loss": train_loss,
"eval_loss": val_loss,
"eval_accuracy": val_accuracy,
"epoch": epoch,
}
# Rank-0: save checkpoint, log to MLflow, and report Ray checkpoint
if ctx.get_world_rank() == 0:
# Save model state to a temp dir
with tempfile.TemporaryDirectory() as tmp:
state_path = os.path.join(tmp, "model.pt")
torch.save(
model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
state_path,
)
# Log to MLflow under UC Volume-backed experiment
mlflow_worker.log_artifact(state_path, artifact_path=f"checkpoints/epoch={epoch}")
# Hand the same state to Ray Train for resumability/best-checkpoint selection
checkpoint = ray.train.Checkpoint.from_directory(tmp)
ray.train.report(metrics, checkpoint=checkpoint)
# Also stream metrics to MLflow from rank 0
mlflow_worker.log_metrics(metrics, step=epoch)
else:
# Non‑zero ranks report metrics only
ray.train.report(metrics)
# Prepare datasets
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
data_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()), "data")
full_train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
test_data = FashionMNIST(root=data_dir, train=False, download=True, transform=transform)
train_size = int(0.8 * len(full_train_data))
val_size = len(full_train_data) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_train_data, [train_size, val_size])
ray_train_ds = torch_dataset_to_ray(train_dataset)
ray_val_ds = torch_dataset_to_ray(val_dataset)
ray_test_ds = torch_dataset_to_ray(test_data)
# Scaling config
scaling_config = ray.train.ScalingConfig(
num_workers=int(ray.cluster_resources().get("GPU")),
use_gpu=True
)
# Add MLflowLoggerCallback to RunConfig
run_config = RunConfig(
name="ray_fashion_mnist_sgc_example",
storage_path=RAY_STORAGE, # persistence path to UC volume
checkpoint_config=ray.train.CheckpointConfig(num_to_keep=10),
callbacks=[
MLflowLoggerCallback(
tracking_uri=mlflow_tracking_uri,
experiment_name=experiment_name,
save_artifact=True,
tags={"tag_id": tag_id}
)
]
)
trainer = ray.train.torch.TorchTrainer(
train_func,
scaling_config=scaling_config,
run_config=run_config,
datasets={
"train": ray_train_ds,
"validation": ray_val_ds,
"test": ray_test_ds,
},
)
result = trainer.fit()
return result
execução do treinamento distribuído Job
A célula a seguir executa a função de treinamento em compute GPU remota serverless . O método .distributed() inicia a execução do Ray Job no recurso de GPU solicitado.
result = my_ray_function.distributed()
result
Registre o modelo no MLflow para inferência.
A célula seguinte recupera a execução mais recente do MLflow no experimento para acessar o ponto de verificação do modelo treinado.
import mlflow
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)
latest_run = None
if experiment:
runs = client.search_runs(
experiment_ids=[experiment.experiment_id],
order_by=["attributes.start_time DESC"],
max_results=1
)
if runs:
latest_run = runs[0]
run_id = latest_run.info.run_id
A célula seguinte carrega o modelo treinado a partir do ponto de verificação, logs o no MLflow com um exemplo de entrada e avalia seu desempenho em uma amostra de dados de teste.
import os
import tempfile
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import mlflow
import numpy as np
# Load the model weights from checkpoint
result = result[0] if type(result) == list else result
with result.checkpoint.as_directory() as checkpoint_dir:
model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
model = resnet18(num_classes=10)
model.conv1 = torch.nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
model.load_state_dict(model_state_dict)
# Prepare datasets
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
data_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()), "data")
test_data = FashionMNIST(root=data_dir, train=False, download=True, transform=transform)
# Get a small batch for the input example
dataloader = torch.utils.data.DataLoader(test_data, batch_size=5, shuffle=True)
batch = next(iter(dataloader))
features = batch[0].numpy() # Images/features
targets = batch[1].numpy() # Labels/targets
# Configure MLflow to use Unity Catalog
mlflow.set_registry_uri("databricks-uc")
# Log the model with input example
with mlflow.start_run(run_id=run_id) as run:
logged_model = mlflow.pytorch.log_model(
model,
"model",
input_example=features,
)
loaded_model = mlflow.pytorch.load_model(logged_model.model_uri)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)
loaded_model.eval()
criterion = CrossEntropyLoss()
num_batches_to_eval = 5
total = 0
correct = 0
test_loss = 0.0
processed_batches = 0
with torch.no_grad():
for i, (images, labels) in enumerate(dataloader):
if i >= num_batches_to_eval:
break
images = images.to(device)
labels = labels.to(device)
outputs = loaded_model(images)
loss = criterion(outputs, labels)
test_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
processed_batches += 1
accuracy = 100.0 * correct / max(total, 1)
avg_loss = test_loss / max(processed_batches, 1)
print("First-5-batch results")
print(f" Accuracy: {accuracy:.2f}%")
print(f" Avg loss: {avg_loss:.4f}")
A célula a seguir registra os modelos no Unity Catalog, tornando-os disponíveis para implantação e governança em todo o workspace.
# Configure MLflow to use Unity Catalog
mlflow.set_registry_uri("databricks-uc")
mlflow.register_model(model_uri=logged_model.model_uri, name=f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}")
Próximos passos
Explore recursos adicionais para treinamento distribuído e compute GPU serverless :
- Melhores práticas para computede GPU serverless
- Solução de problemas compute GPU serverless
- Treinamento distribuído com múltiplas GPUs e múltiplos nós
- Ray ensina documentação
- Registro de modelo MLflow com Unity Catalog