PyTorch
The PyTorch project is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks. For licensing details, see the PyTorch license doc on GitHub.
To monitor and debug your PyTorch models, consider using TensorBoard.
PyTorch is included in Databricks Runtime for Machine Learning. If you are using Databricks Runtime, see Install PyTorch for instructions on installing PyTorch.
Note
This is not a comprehensive guide to PyTorch. For more information, see the PyTorch website.
Single node and distributed training
To test and migrate single-machine workflows, use a Single Node cluster.
For distributed training options for deep learning, see Distributed training.
Install PyTorch
Databricks Runtime for ML
Databricks Runtime for Machine Learning includes PyTorch, so you can create the cluster and start using PyTorch right away. For the version of PyTorch installed in the Databricks Runtime ML version you are using, see the release notes.
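The release notes are the authoritative source for bundled versions. As a quick cross-check from a notebook, a minimal sketch such as the following reads the versions from the installed package metadata. The helper name installed_versions is hypothetical and not part of any Databricks or PyTorch API, and the sketch assumes Python 3.8 or later, where importlib.metadata is in the standard library.

```python
from importlib import metadata


def installed_versions(packages=("torch", "torchvision")):
    """Return a dict mapping each package name to its installed version.

    Packages that are not installed map to None instead of raising.
    (Hypothetical helper, for illustration only.)
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one line per package, e.g. the version string or "not installed".
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Reading the metadata avoids importing torch itself, so the check is cheap and also works on a cluster where the package is missing.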
Databricks Runtime
Databricks recommends that you use the PyTorch included in Databricks Runtime for Machine Learning. However, if you must use Databricks Runtime, you can install PyTorch as a Databricks PyPI library. The following example shows how to install PyTorch 1.5.0:
On GPU clusters, install pytorch and torchvision by specifying the following:
torch==1.5.0
torchvision==0.6.0
On CPU clusters, install pytorch and torchvision by using the following wheel files:
https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
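The two CPU wheel URLs above share one naming pattern: package name, version with a +cpu local suffix (percent-encoded as %2B), then the Python, ABI, and platform tags. The following sketch rebuilds those URLs from their parts to make the pattern explicit. The function cpu_wheel_url and its parameter names are hypothetical, and the pattern is only confirmed for the two files listed here; other version or tag combinations may not exist on the download server.

```python
from urllib.parse import quote

# Base location taken from the wheel URLs listed in this section.
BASE_URL = "https://download.pytorch.org/whl/cpu"


def cpu_wheel_url(package, version, python_tag="cp37", abi_tag="cp37m",
                  platform_tag="linux_x86_64"):
    """Build a CPU wheel URL following the pattern of the URLs above.

    Hypothetical helper: it only formats a string and does not check
    that the resulting file exists.
    """
    # CPU-only builds carry a "+cpu" local version suffix.
    filename = f"{package}-{version}+cpu-{python_tag}-{abi_tag}-{platform_tag}.whl"
    # quote() percent-encodes "+" as "%2B" and leaves letters, digits,
    # "-", "_", and "." unchanged.
    return f"{BASE_URL}/{quote(filename)}"


if __name__ == "__main__":
    print(cpu_wheel_url("torch", "1.5.0"))
    print(cpu_wheel_url("torchvision", "0.6.0"))
```

The cp37 and cp37m tags mean these wheels target CPython 3.7, so they are only installable on a cluster whose Python version matches.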