Train a PyTorch model

PyTorch is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks.
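The two halves of that description look roughly like this in practice. This is a minimal sketch, assuming only that the `torch` package is installed. It runs a matrix multiply on a GPU when one is available (falling back to the CPU), then builds a small fully connected classifier for 28x28 grayscale images from `torch.nn` building blocks. The layer sizes and the random input batch are illustrative choices, not taken from the notebook.

```python
import torch
from torch import nn

# Use a GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensor computation: the matrix multiply runs on whichever device holds the tensors.
a = torch.randn(64, 784, device=device)
w = torch.randn(784, 10, device=device)
logits = a @ w  # shape: (64, 10)

# High-level building blocks: a small classifier for 28x28 single-channel images.
# The hidden size (128) is an arbitrary illustrative choice.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).to(device)

# A fake batch of 8 grayscale images, created on the same device as the model.
images = torch.randn(8, 1, 28, 28, device=device)
scores = model(images)  # one score per digit class for each image

print(logits.shape, scores.shape)
```

The same code runs unchanged on CPU-only and GPU clusters because every tensor and the model are placed on `device` explicitly.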

The MLflow PyTorch notebook fits a neural network on MNIST handwritten digit recognition data and logs run results to an MLflow server. It logs training metrics and weights in TensorFlow event format locally and then uploads them to the MLflow run’s artifact directory. Finally, it starts TensorBoard and reads the events logged locally.

When you’re ready, you can deploy your model using Model serving with Databricks.

MLflow PyTorch model training notebook
