Distributed Deep Learning with TensorFlowOnSpark


If you’re new to TensorFlow or deep learning, we recommend that you first work through the single-node TensorFlow guide.

In this guide we explore Yahoo’s TensorFlowOnSpark framework for distributed deep learning on Spark clusters, describing how it works and highlighting typical usage patterns. We then walk through a set of notebooks (referenced throughout the guide) to demonstrate how to train a neural network on MNIST in a distributed setting with TensorFlowOnSpark.

We’ve packaged the example notebooks for this guide into a DBC archive so you can easily import them into Databricks and run them yourself:

Download the example notebooks for this guide

What is TensorFlowOnSpark?

TensorFlowOnSpark, developed by Yahoo, is an open-source Python framework for launching TensorFlow-based distributed deep learning workflows on Spark clusters (see the original Yahoo blog post).

TensorFlowOnSpark simplifies deep learning training on Spark clusters by:

  • Abstracting away the need to manually specify a cluster configuration (mapping from CPUs/GPUs to TensorFlow tasks)
  • Providing APIs for feeding data from Spark RDDs to TensorFlow programs
  • Leveraging Spark’s built-in fault-tolerance (to recover from failures during TF training)

How does TensorFlowOnSpark work?

TensorFlowOnSpark launches distributed model training as a single Spark job. The entry point to model training is TFCluster.run (see its API documentation), which runs a user-specified Python function in each Spark executor, passing a context argument ctx. The context argument contains global and executor-specific information about the current distributed training run. For example, ctx.job_name indicates whether the current Spark executor should act as a TensorFlow worker or parameter server, while ctx.cluster_spec contains a global mapping from TensorFlow job name ("ps" or "worker") to CPU/GPU devices. The user-specified function is responsible for inspecting the context argument and running the appropriate logic (for example, parameter server or worker code).
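To make this concrete, here is a minimal sketch of the pattern (not code from the example notebooks): a user-defined function inspects ctx and runs either parameter-server or worker logic, and TFCluster.run launches one copy of it per executor. The sketch uses TF 1.x-style tf.train.Server primitives; the exact argument names and defaults of TFCluster.run may differ slightly across TensorFlowOnSpark versions, and sc is the notebook’s SparkContext.

    from tensorflowonspark import TFCluster

    def map_fun(args, ctx):
        # Runs once inside each Spark executor; ctx describes this executor's role.
        import tensorflow as tf
        cluster_spec = tf.train.ClusterSpec(ctx.cluster_spec)
        server = tf.train.Server(cluster_spec, job_name=ctx.job_name,
                                 task_index=ctx.task_index)
        if ctx.job_name == "ps":
            # Parameter server: host model variables and block until training ends.
            server.join()
        else:
            # Worker: build the model graph and run the training loop here,
            # using server.target as the session master.
            with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
                pass  # ... define the model ...
            # ... run the training loop ...

    # Launch one TensorFlow task per Spark executor; one of them acts as a parameter server.
    cluster = TFCluster.run(sc, map_fun, None, num_executors=4, num_ps=1,
                            tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()  # block until training completes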

A Typical Workflow

TensorFlowOnSpark workflows typically consist of three parts: model training, evaluation against held-out samples, and monitoring via TensorBoard.

Model Training

  • Runs on the Spark workers.
  • Copies a distinct subset of the training data from S3 to each worker’s local hard drive.
  • Fits model against local data.
  • Stores model checkpoints in DBFS.
  • Stores event files (containing information on training loss) persistently in DBFS.
  • Periodically syncs event files to the driver to be consumed by TensorBoard, which runs on the driver.
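The periodic sync of event files from DBFS to the driver can be implemented as a simple copy loop run on the driver. The sketch below is illustrative rather than the notebooks’ actual code: the directory paths and polling interval are placeholders, and it relies on dbutils.fs.cp, which is available in Databricks notebooks.

    import time

    DBFS_EVENT_DIR = "dbfs:/mnist/events"       # illustrative DBFS location of the event files
    LOCAL_EVENT_DIR = "file:/tmp/mnist/events"  # illustrative driver-local directory read by TensorBoard

    # Illustrative polling loop, run on the driver: copy new event files once a minute
    # for roughly as long as training is expected to run.
    for _ in range(60):
        dbutils.fs.cp(DBFS_EVENT_DIR, LOCAL_EVENT_DIR, recurse=True)
        time.sleep(60)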


Model Evaluation

  • Runs on the Spark driver during training.
  • Stores event files (containing accuracy information on an evaluation set) on the driver to be read by TensorBoard, periodically syncing them to DBFS for persistent storage.


Monitoring with TensorBoard

  • Runs on the Spark driver to read and display event files that contain metrics about the training run. Using TensorBoard in a distributed environment requires some additional setup, described in the example notebooks below.
  • See the single-node TensorFlow guide for an overview of TensorBoard.


We recommend training against data downloaded from S3 to the local disks of the workers. We’ve found this approach to be the most robust, although it limits the size of the training dataset to the total disk space available across all workers. If needed, you can increase the disk space available on your cluster by attaching EBS volumes (link). The notebooks below also include a utility function to download data from S3 to the worker disk.
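As a rough sketch of that pattern (not the notebooks’ actual utility), the code below downloads a distinct shard of a set of S3 objects to each worker’s local disk using boto3. The bucket name, prefix, local directory, and worker count are placeholders, and it assumes boto3 and AWS credentials are available on the workers.

    # Placeholder values for illustration.
    BUCKET_NAME = "my-training-data"
    DATA_PREFIX = "mnist/train/"
    LOCAL_DATA_DIR = "/tmp/mnist/train"
    NUM_WORKERS = 2

    def download_shard(partition_index, _records):
        # Runs on a worker: downloads a distinct shard of the S3 objects to local disk.
        import os
        import boto3
        os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
        bucket = boto3.resource("s3").Bucket(BUCKET_NAME)
        keys = sorted(obj.key for obj in bucket.objects.filter(Prefix=DATA_PREFIX))
        for key in keys[partition_index::NUM_WORKERS]:
            bucket.download_file(key, os.path.join(LOCAL_DATA_DIR, os.path.basename(key)))
        yield partition_index

    # One partition per worker (assuming one Spark executor per worker machine),
    # so each worker downloads its own subset of the training files.
    sc.parallelize(range(NUM_WORKERS), NUM_WORKERS) \
      .mapPartitionsWithIndex(download_shard) \
      .collect()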

Example Notebooks: Distributed Training on MNIST

The following pages of this guide walk through a series of notebooks demonstrating how to train a neural network on MNIST in a distributed setting using TensorFlowOnSpark.

  • Helper Notebooks (shared by model training & evaluation)

    • Data Ingest: Demonstrates how to download data from S3 and create a data ingest pipeline (loading training data from disk into in-memory tensors) using TensorFlow’s tf.data APIs. Our sample data is already stored in a public S3 bucket and has been split into training & validation sets.
    • Constructing the Model Graph: Demonstrates how to construct a TensorFlow graph for distributed model training; it uses distributed TensorFlow primitives but does not contain any TensorFlowOnSpark code.
    • Constants: Contains constants used for model training and evaluation (hyperparameters, the location of training & test data on S3, etc.).
  • Launching Model Training:

    To launch model training, you need only run this notebook. It uses the helpers defined in the preceding notebooks to build the model graph, then calls TensorFlowOnSpark APIs to launch distributed model training on the Spark workers.

  • Model Evaluation

    This notebook runs solely on the Spark driver and should be run concurrently with the Launching Model Training notebook. It downloads a validation dataset from S3, periodically loads the partially-trained model from checkpoint files stored in DBFS, computes model accuracy on the validation dataset, and writes summary information to a local event file to be consumed by TensorBoard.

  • TensorBoard

    This notebook runs TensorBoard on the driver, consuming event files from a directory on the driver’s local filesystem.
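A minimal way to do this (a sketch, not necessarily the notebook’s exact mechanism) is to start TensorBoard as a background process on the driver, pointed at the driver-local directory that the event files are synced to; the path below is illustrative.

    import subprocess

    LOCAL_EVENT_DIR = "/tmp/mnist/events"  # illustrative driver-local event-file directory

    # Start TensorBoard in the background on the driver; by default it serves its UI on port 6006.
    tensorboard_process = subprocess.Popen(
        ["tensorboard", "--logdir", LOCAL_EVENT_DIR, "--port", "6006"])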

Running the Example Notebooks

The example notebooks are available as a DBC archive. They are designed to run on clusters with multiple worker machines, and therefore will not run on Community Edition.

We recommend using at least two Spark workers, with TensorFlow and the TensorFlowOnSpark egg (see Building TensorFlowOnSpark below) installed as libraries on the cluster.

If you use a CPU cluster, configure Spark to use a single executor per machine (this is the default setting on GPU-enabled clusters). You can do this by setting “spark.executor.cores” to “1” on the cluster creation page. See Spark Config for more info.
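For example, the corresponding entry in the Spark Config box on the cluster creation page would be:

    spark.executor.cores 1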

Building TensorFlowOnSpark

To build a TensorFlowOnSpark egg for use on Databricks, clone the project (link) and run python setup.py bdist_egg from the project’s root directory. The example notebooks in this guide have been tested against an egg built from TensorFlowOnSpark commit 9a46b288c9eef6646a49155a808058ff419efee6.

FAQ & Debugging Tips

My TensorFlowOnSpark program hangs while “waiting for x reservations”

This can happen if your cluster has fewer executors than the number passed to TFCluster.run; in that case, simply pass a number of executors that matches your cluster size. However, you might also encounter this issue on GPU clusters if your program hits an error while running and you then attempt to rerun it. In this case, we recommend restarting your cluster.

More detail: The issue is caused by TensorFlowOnSpark running parameter server logic in a subprocess of a PySpark worker process. The parameter server subprocess acquires a GPU and then blocks. Since it’s a subprocess, the parameter server doesn’t die when its parent process is killed on task failure. Instead, it continues to hold a GPU, blocking future attempts to reserve a GPU.

How do I view logs generated during model training?

Model training runs on the Spark workers, so training logs appear in each worker’s stderr. To view them, navigate to Clusters, click the current cluster, and open Spark Cluster UI - Master; then click an individual worker and view its stderr.

Where are model checkpoints and event files stored in the example workflow?

Here’s a summary of what gets stored where:

  • Training Data: Stored on S3, copied to a local directory on each worker
  • Model checkpoints: Stored on DBFS
  • Event files: Stored on DBFS (generated locally by the driver and chief worker, then synced to DBFS). Event files are also periodically synced to a local directory on the driver, to be consumed by TensorBoard.

In our example, we define the destination directories of model checkpoints and event files in a single notebook. See Model Training & Evaluation Constants.
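For illustration only (these values are placeholders, not the notebooks’ actual paths), such constants might look like the following:

    # Placeholder paths for illustration; the Constants notebook defines the real values.
    LOCAL_DATA_DIR = "/tmp/mnist/train"          # per-worker local copy of the S3 training data
    CHECKPOINT_DIR = "/dbfs/mnist/checkpoints"   # model checkpoints, persisted in DBFS
    DBFS_EVENT_DIR = "dbfs:/mnist/events"        # event files, persisted in DBFS
    LOCAL_EVENT_DIR = "/tmp/mnist/events"        # driver-local copy consumed by TensorBoard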

Can I run TensorFlowOnSpark on CPU-only clusters?

Yes, although it is likely more cost-effective to run your distributed training code on GPU clusters.