Train a PyTorch model

PyTorch is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks.

The MLflow PyTorch notebook fits a neural network on the MNIST handwritten digit recognition dataset. The run results are logged to an MLflow server. Training metrics and weights are written locally as TensorFlow event files and then uploaded to the MLflow run’s artifact directory. Finally, TensorBoard is started and reads the locally logged events.

This example runs on Databricks Runtime for Machine Learning and above. To install PyTorch on a cluster running Databricks Runtime 5.0 ML, run the PyTorch init script notebook to create an init script named and configure your cluster with the init script. If you run Databricks Runtime 5.1 ML (Beta) or above, you do not need to create the PyTorch init script or configure your cluster with it.

If you want to run TensorBoard to read the artifacts uploaded to S3, see How to run TensorFlow on S3.

MLflow PyTorch model training notebook