PyTorch is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks. For licensing details, see the PyTorch license doc on GitHub.

In the sections below, we provide guidance on installing PyTorch on Databricks and give an example of running PyTorch programs. See Integrating Deep Learning Libraries with Apache Spark for an example of integrating a deep learning library with Spark.


This is not a comprehensive guide to PyTorch. For more information, refer to the PyTorch website.

Install PyTorch


PyTorch is included in Databricks Runtime 5.1 ML and above, a machine learning runtime that provides a ready-to-go environment for machine learning and data science. Instead of installing PyTorch using the instructions below, you can simply create a cluster using Databricks Runtime ML. See Overview of Databricks Runtime for Machine Learning.

PyTorch can be installed as a Databricks PyPI library. Create two libraries: one with the torch wheel URL for your version of Python, and one for torchvision. Databricks recommends PyTorch version 0.4.1:

  • Python 2: <cuda-version>/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
  • Python 3: <cuda-version>/torch-0.4.1-cp35-cp35m-linux_x86_64.whl

where <cuda-version> is 90 for Databricks Runtime 5.x and 4.x and 80 for Databricks Runtime 3.x.
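The substitution rule above can be sketched as a small lookup. This is only an illustration: the names WHEEL_TAILS, CUDA_VERSIONS, cuda_version, and wheel_tail are hypothetical helpers, not part of any Databricks or PyTorch API. The function returns only the tail of the wheel path exactly as listed above, so the beginning of the URL still has to be supplied separately.

```python
# Tail of the wheel path for each Python major version, copied from the list
# above. The "<cuda-version>" placeholder is filled in by wheel_tail().
WHEEL_TAILS = {
    2: "<cuda-version>/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
    3: "<cuda-version>/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
}

# <cuda-version> value for each Databricks Runtime major version (3.x, 4.x, 5.x).
CUDA_VERSIONS = {3: "80", 4: "90", 5: "90"}


def cuda_version(runtime_major):
    """Return the <cuda-version> string for a Databricks Runtime major version."""
    if runtime_major not in CUDA_VERSIONS:
        raise ValueError(
            "no CUDA version listed for Databricks Runtime %s.x" % runtime_major
        )
    return CUDA_VERSIONS[runtime_major]


def wheel_tail(python_major, runtime_major):
    """Return the wheel path tail with <cuda-version> substituted."""
    if python_major not in WHEEL_TAILS:
        raise ValueError("no wheel listed for Python %s" % python_major)
    tail = WHEEL_TAILS[python_major]
    return tail.replace("<cuda-version>", cuda_version(runtime_major))


print(wheel_tail(3, 5))  # prints 90/torch-0.4.1-cp35-cp35m-linux_x86_64.whl
```

For example, Python 2 on Databricks Runtime 3.x resolves to the cp27 wheel with 80 as the CUDA version.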

Use PyTorch on a single node

To test and migrate single-machine PyTorch workflows, you can start with a driver-only cluster on Databricks by setting the number of workers to zero. Although Apache Spark is not functional in this configuration, a driver-only cluster is a cost-effective way to run single-machine PyTorch workflows.
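As a quick check that PyTorch runs on the driver, a notebook cell along the following lines might be used. It is a minimal sketch, assuming Python 3 and a PyTorch version of 0.4 or later: it fits the line y = 2x + 1 with a single linear layer, using the GPU if one is visible and the CPU otherwise. The function names torch_is_available and fit_line are illustrative and do not come from the Databricks documentation.

```python
import importlib.util


def torch_is_available():
    """Return True if the torch package can be found on this node."""
    return importlib.util.find_spec("torch") is not None


def fit_line(steps=200, lr=0.1):
    """Fit y = 2x + 1 with one linear layer and return the final MSE loss."""
    # Imported here so that torch_is_available() works without PyTorch.
    import torch

    torch.manual_seed(0)
    # Use the GPU when the driver has one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Synthetic data: 64 points on the line y = 2x + 1, shaped (64, 1).
    x = torch.linspace(-1, 1, 64, device=device).unsqueeze(1)
    y = 2 * x + 1

    model = torch.nn.Linear(1, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()

    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


if torch_is_available():
    print("final loss:", fit_line())
else:
    print("PyTorch is not installed on this node")
```

Everything here runs in the driver's Python process, so no Spark workers are needed.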