Pular para o conteúdo principal

PyTorch

O projeto PyTorch é um pacote Python que fornece computação tensorial acelerada por GPU e funcionalidades de alto nível para a criação de redes de aprendizagem profunda. Para obter detalhes sobre a licença, consulte o documento de licença do PyTorch no GitHub.

Para monitorar e depurar seus modelos PyTorch, considere o uso do TensorBoard.

O PyTorch está incluído no Databricks Runtime for Machine Learning. Se o senhor estiver usando o Databricks Runtime, consulte Instalar o PyTorch para obter instruções sobre como instalar o PyTorch.

nota

Este não é um guia completo do PyTorch. Para obter mais informações, consulte o sitePyTorch.

Treinamento distribuído e de nó único

Para testar e migrar o fluxo de trabalho de uma única máquina, use um clustering de nó único.

Para conhecer as opções de treinamento distribuído para aprendizagem profunda, consulte Treinamento distribuído.

Exemplo de notebook

PyTorch Caderno de anotações

Open notebook in new tab

Instalar o PyTorch

Databricks Runtime para ML

Databricks Runtime for Machine Learning inclui PyTorch para que o senhor possa criar o clustering e começar a usar PyTorch. Para saber a versão do PyTorch instalada na versão Databricks Runtime ML que o senhor está usando, consulte as notas sobre a versão.

Databricks Runtime

A Databricks recomenda que o senhor use o PyTorch incluído no Databricks Runtime for Machine Learning. Entretanto, se o senhor precisar usar o padrão Databricks Runtime, PyTorch pode ser instalado como uma bibliotecaDatabricks PyPI. O exemplo a seguir mostra como instalar o PyTorch 1.5.0:

  • No clustering de GPU, instale pytorch e torchvision especificando o seguinte:

    • torch==1.5.0
    • torchvision==0.6.0
  • No clustering da CPU, instale pytorch e torchvision usando os seguintes arquivos Python wheel:

    https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl

    https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl

Erros e solução de problemas para o PyTorch distribuído

As seções a seguir descrevem mensagens de erro comuns e orientações de solução de problemas para as classes: PyTorch DataParallel ou PyTorch DistributedDataParallel. A maioria desses erros provavelmente pode ser resolvida com o TorchDistributorque está disponível em Databricks Runtime ML 13.0 e acima. No entanto, se o site TorchDistributor não for uma solução viável, também são fornecidas soluções recomendadas em cada seção.

A seguir, um exemplo de como usar o TorchDistributor:

Python

from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
# ...

num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)

distributor.run(train_fn, 1e-3)

processo 0 encerrado com o código de saída 1

O seguinte erro pode ocorrer ao usar o Notebook em Databricks ou localmente:

process 0 terminated with exit code 1

Para evitar esse erro, use torch.multiprocessing.start_processes com start_method=fork em vez de torch.multiprocessing.spawn.

Por exemplo:

Python
import torch

def train_fn(rank, learning_rate):
# required setup, e.g. setup(rank)
# ...

num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")

O soquete do servidor falhou ao se vincular à porta

O seguinte erro aparece quando o senhor reinicia o treinamento distribuído após interromper a célula durante o treinamento:

The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).

Para corrigir o problema, reinicie o clustering. Se a reinicialização não resolver o problema, pode haver um erro no código da função de treinamento.

Erros relacionados ao CUDA

O senhor pode ter problemas adicionais com,CUDA pois start_method=”fork” não CUDAé compatível com. O uso de qualquer comando .cuda em qualquer célula pode levar a falhas. Para evitar esses erros, adicione a seguinte verificação antes de chamar torch.multiprocessing.start_method:

Python
if torch.cuda.is_initialized():
raise Exception("CUDA was initialized; distributed training will fail.") # or something similar