Best practices for deep learning on Databricks

Databricks Machine Learning provides pre-built deep learning infrastructure with Databricks Runtime for Machine Learning. Databricks Runtime for Machine Learning includes the most common deep learning libraries, such as TensorFlow, PyTorch, and Keras, and supporting libraries such as Petastorm, Hyperopt, and Horovod. It also has built-in, pre-configured GPU support, including drivers and supporting libraries.

Databricks Runtime ML also includes all of the capabilities of the Databricks workspace, such as cluster creation and management, library and environment management, code management with Databricks Repos, automation support including Databricks Jobs and APIs, and integrated MLflow for model development tracking and model deployment and serving.

With Databricks, you can use any library to create the logic to train your model. The preconfigured Databricks runtime makes it possible to easily scale common machine learning and deep learning steps. This article includes tips for deep learning on Databricks and information about built-in tools and libraries designed to optimize deep learning workloads, such as Delta Lake and Petastorm for loading data, Horovod and Hyperopt for distributed training and tuning, and pandas UDFs for inference.

Resource and environment management

Databricks helps you to both customize your deep learning environment and keep the environment consistent across users.

Customize the development environment

With Databricks Runtime, you can customize your development environment at the notebook, cluster, and job levels.

Use cluster policies

You can create cluster policies to guide data scientists to the right choices, such as using a Single Node cluster for development and using an autoscaling cluster for large jobs.
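
As a rough illustration, a policy that steers development work onto Single Node GPU clusters could be drafted as below. This is a sketch under assumptions: the attribute paths and policy types follow the cluster policy definition format as commonly documented, and the instance types and time limits are hypothetical choices. Verify each attribute against the policy reference for your workspace before using it.

```python
import json

# Hypothetical policy definition for Single Node development clusters.
# Attribute paths and policy types are assumptions based on the cluster
# policy JSON format; instance types and limits are illustrative only.
single_node_policy = {
    # Pin the cluster to Single Node mode (driver only, no workers).
    "spark_conf.spark.databricks.cluster.profile": {
        "type": "fixed",
        "value": "singleNode",
        "hidden": True,
    },
    "num_workers": {"type": "fixed", "value": 0, "hidden": True},
    # Restrict users to a short list of GPU instance types (hypothetical).
    "node_type_id": {
        "type": "allowlist",
        "values": ["g4dn.xlarge", "g4dn.12xlarge"],
        "defaultValue": "g4dn.xlarge",
    },
    # Cap idle time so forgotten development clusters shut down.
    "autotermination_minutes": {
        "type": "range",
        "maxValue": 120,
        "defaultValue": 60,
    },
}

# Serialize to the JSON text that a policy definition field expects.
policy_json = json.dumps(single_node_policy, indent=2)
print(policy_json)
```

A second policy for large jobs would follow the same shape, replacing the fixed worker count with limits on the autoscaling range.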

Best practices for loading data

Cloud data storage is typically not optimized for I/O, which can be a challenge for deep learning models that require large datasets. Databricks Runtime ML includes Delta Lake and Petastorm to optimize data throughput for deep learning applications.

Databricks recommends using Delta Lake tables for data storage. Delta Lake simplifies ETL and lets you access data efficiently. Especially for images, Delta Lake helps optimize ingestion for both training and inference. The reference solution for image applications provides an example of optimizing ETL for images using Delta Lake.

Petastorm provides APIs that let you prepare data in Parquet format for use by TensorFlow, Keras, or PyTorch. The SparkConverter API provides Spark DataFrame integration. Petastorm also provides data sharding for distributed processing. See Load data using Petastorm for details.

Best practices for training deep learning models

Databricks recommends using Databricks Runtime for Machine Learning, together with MLflow tracking and autologging, for all model training.

Start with a Single Node cluster

A Single Node (driver only) GPU cluster is typically fastest and most cost-effective for deep learning model development. One node with 4 GPUs is likely to be faster for deep learning training than 4 worker nodes with 1 GPU each. This is because distributed training incurs network communication overhead.

A Single Node cluster is a good option during fast, iterative development and for training models on small- to medium-size data. If your dataset is large enough to make training slow on a single machine, consider moving to multi-GPU and even distributed compute.

Use TensorBoard and Ganglia to monitor the training process

TensorBoard is preinstalled in Databricks Runtime ML. You can use it within a notebook or in a separate tab. See TensorBoard for details.

Ganglia is available in all Databricks runtimes. You can use it to examine network, processor, and memory usage to inspect for bottlenecks. See Ganglia metrics for details.

Optimize performance for deep learning

You can, and should, use deep learning performance optimization techniques on Databricks.

Early stopping

Early stopping monitors the value of a metric calculated on the validation set and stops training when the metric stops improving. This is a better approach than guessing at a good number of epochs to complete. Each deep learning library provides a native API for early stopping; for example, see the EarlyStopping callback APIs for TensorFlow/Keras and for PyTorch Lightning. For an example notebook, see Get started with TensorFlow Keras in Databricks.
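
The logic behind these callbacks is small enough to sketch without any framework. The class below is a hypothetical, framework-agnostic illustration, not the Keras or PyTorch Lightning implementation; the names `patience` and `min_delta` are borrowed from the Keras callback's parameters, and the validation losses are made-up values.

```python
class EarlyStopper:
    """Framework-agnostic sketch of patience-based early stopping."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest change that counts as improvement
        self.best = float("inf")      # best (lowest) validation loss seen so far
        self.bad_epochs = 0           # consecutive epochs without improvement

    def should_stop(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation losses, one per epoch.
val_losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.60, 0.59]

stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best validation loss {stopper.best}")
```

In this run the loss stops improving after epoch 3, so training halts three epochs later and the final epoch is never reached. In practice, prefer the library's own callback, which also handles details such as restoring the best weights.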

Batch size tuning

Batch size tuning helps optimize GPU utilization. If the batch size is too small, the calculations cannot fully use the GPU capabilities. Use Ganglia to view GPU metrics while you tune.

Adjust the batch size in conjunction with the learning rate. A good rule of thumb is, when you increase the batch size by a factor of n, increase the learning rate by a factor of sqrt(n). When tuning manually, try changing the batch size by a factor of 2 or 0.5. Then continue tuning to optimize performance, either manually or by testing a variety of hyperparameters using an automated tool like Hyperopt.
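
The square-root rule of thumb can be written as a small helper. This is a minimal sketch; the function name and the starting values are hypothetical, and the result is only a starting point for further tuning.

```python
import math


def scale_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate by sqrt(n), where n is the batch size ratio."""
    n = new_batch_size / base_batch_size
    return base_lr * math.sqrt(n)


# Hypothetical starting point: batch size 64 with learning rate 0.001.
base_lr, base_batch_size = 0.001, 64

# Manual tuning step: try half and double the current batch size.
candidates = {
    batch_size: scale_learning_rate(base_lr, base_batch_size, batch_size)
    for batch_size in (base_batch_size // 2, base_batch_size * 2)
}

# Quadrupling the batch size (n = 4) doubles the learning rate.
lr_at_256 = scale_learning_rate(base_lr, base_batch_size, 256)
print(candidates, lr_at_256)
```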

Transfer learning

With transfer learning, you start with a previously trained model and modify it as needed for your application. Transfer learning can significantly reduce the time required to train and tune a new model. See Featurization for transfer learning for more information and an example.

Move to distributed training

Databricks Runtime ML includes HorovodRunner, spark-tensorflow-distributor, and Hyperopt to facilitate the move from single-node to distributed training.

  • Horovod is an open-source project that scales deep learning training to multi-GPU or distributed computation. HorovodRunner, built by Databricks and included in Databricks Runtime ML, is a Horovod wrapper that provides Spark compatibility. The API lets you scale single-node code with minimal changes. HorovodRunner works with TensorFlow, Keras, and PyTorch.

  • spark-tensorflow-distributor is an open-source, TensorFlow-native package for distributed TensorFlow training on Spark clusters. See the example notebook.

  • Hyperopt provides adaptive hyperparameter tuning for machine learning. With the SparkTrials class, you can iteratively tune parameters for deep learning models in parallel across a cluster.

Best practices for inference

This section contains general tips about using models for inference with Databricks.

  • To minimize costs, consider both CPUs and inference-optimized GPUs such as the Amazon EC2 G4 and G5 instances. There is no clear recommendation, as the best choice depends on model size, data dimensions, and other variables.

  • Use MLflow to simplify deployment and model serving. MLflow can log any deep learning model, including custom preprocessing and postprocessing logic. Models registered in the MLflow Model Registry can be deployed for batch, streaming, or online inference.

Batch and streaming inference

Batch and streaming inference support high-throughput, low-cost scoring at latencies as low as minutes. For more information, see Offline (batch) predictions.

  • If you expect to access data for inference more than once, consider creating a preprocessing job to ETL the data into a Delta Lake table before running the inference job. This way, the cost of ingesting and preparing the data is spread across multiple reads of the data. Separating preprocessing from inference also allows you to select different hardware for each job to optimize cost and performance. For example, you might use CPUs for ETL and GPUs for inference.

  • Use Spark pandas UDFs to scale batch and streaming inference across a cluster.
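
One pattern worth keeping in mind when scoring with pandas UDFs is to load the model once per task rather than once per batch, which the iterator form of a pandas UDF makes possible. The sketch below shows only that pattern in plain Python, so it runs without Spark. `load_model` and the doubling "model" are hypothetical stand-ins. On a cluster, a generator with this shape would be decorated with `pyspark.sql.functions.pandas_udf` and would receive and yield `pandas.Series` batches instead of lists.

```python
from typing import Callable, Iterable, Iterator, List

load_count = 0  # counts how often the (expensive) model load runs


def load_model() -> Callable[[List[float]], List[float]]:
    """Hypothetical stand-in for an expensive model load."""
    global load_count
    load_count += 1
    # Toy "model": doubles every input value.
    return lambda batch: [2.0 * x for x in batch]


def predict_batches(batches: Iterable[List[float]]) -> Iterator[List[float]]:
    """Score an iterator of batches, loading the model a single time."""
    model = load_model()          # paid once per iterator, not once per batch
    for batch in batches:
        yield model(batch)        # one output batch per input batch


# Three hypothetical input batches of different sizes.
results = list(predict_batches([[1.0, 2.0], [3.0], [4.0, 5.0]]))
print(results, "model loads:", load_count)
```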

Online serving

The best option for low-latency serving is online serving behind a REST API. Databricks provides Databricks Model Serving for online inference.
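
A client call to such a REST endpoint can be sketched with the Python standard library. Everything specific here is an assumption for illustration: the workspace URL, endpoint name, token placeholder, and feature name are hypothetical, and the `/serving-endpoints/<name>/invocations` path and `dataframe_records` payload key should be checked against the Model Serving documentation for your workspace. The example only builds the request and does not send it.

```python
import json
import urllib.request


def build_scoring_request(workspace_url, endpoint_name, token, records):
    """Build (but do not send) a POST request that scores a list of records."""
    # Assumed invocation path for a serving endpoint.
    url = f"{workspace_url.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    # Assumed payload format: one JSON object per input row.
    body = json.dumps({"dataframe_records": records}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


# Hypothetical workspace, endpoint, token placeholder, and input row.
request = build_scoring_request(
    "https://example.cloud.databricks.com",
    "my-model",
    "<token>",
    [{"feature_a": 1.0}],
)
print(request.get_method(), request.full_url)

# To send the request and read the JSON response:
#   with urllib.request.urlopen(request) as response:
#       predictions = json.load(response)
```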

MLflow provides APIs for deploying to various managed services for online inference, as well as APIs for creating Docker containers for custom serving solutions.

You can also use SageMaker Serving. See the example notebook and the MLflow documentation.