

TensorFlow tutorial - MNIST For ML Beginners

This notebook demonstrates how to use TensorFlow on the Spark driver node to fit a neural network on MNIST handwritten digit recognition data.

Requirements:
  • A GPU-enabled cluster.
  • TensorFlow 1.15 or 2.x with GPU support installed manually.

The content of this notebook is adapted from the TensorFlow project under the Apache 2.0 license, with slight modifications to run on Databricks. Thanks to the developers of TensorFlow for this example!

import tensorflow as tf
from tensorflow.keras import models, layers, datasets
import datetime, uuid
# Verify that we're using TensorFlow 2.x or 1.15
assert tf.__version__.startswith("2.") or tf.__version__.startswith("1.15")

Load the data and scale the pixel values to [0, 1] (downloading the dataset may take a while)

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
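MNIST pixels arrive as uint8 values in [0, 255], and dividing by 255.0 maps them to floats in [0, 1], which generally keeps the inputs to the first layer on a small, consistent scale. A minimal NumPy sketch of that step on a tiny made-up array (not real MNIST data):

```python
import numpy as np

# A tiny stand-in for MNIST pixels: uint8 values in [0, 255]
pixels = np.array([[0, 64, 128, 255]], dtype=np.uint8)

# Dividing by a Python float produces a float array and maps the range to [0.0, 1.0]
scaled = pixels / 255.0

print(scaled.dtype)
print(scaled.min(), scaled.max())
```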

Define the model

model = models.Sequential([
  layers.Flatten(input_shape=(28, 28)),
  layers.Dense(10, activation='softmax')
])

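This model is multinomial logistic (softmax) regression: Flatten turns each 28x28 image into a 784-element vector, and Dense(10, activation='softmax') applies an affine map followed by softmax to get one probability per digit. The NumPy sketch below reproduces that forward pass with random, hypothetical weights (not trained ones) purely to show the shapes and the parameter count; it is not how Keras implements the layers internally.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability; softmax is unchanged by this shift
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(images, W, b):
    # Flatten: (batch, 28, 28) -> (batch, 784)
    x = images.reshape(images.shape[0], -1)
    # Dense(10) with softmax: affine map, then softmax over the 10 classes
    return softmax(x @ W + b)

rng = np.random.default_rng(0)
images = rng.random((3, 28, 28))              # three fake images with values in [0, 1)
W = rng.normal(scale=0.01, size=(784, 10))    # hypothetical untrained weights
b = np.zeros(10)                              # hypothetical bias vector

probs = forward(images, W, b)
n_params = W.size + b.size                    # 784 * 10 weights + 10 biases

print(probs.shape)
print(n_params)
```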
Define loss and optimizer
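
# Integer labels (0-9) pair with sparse categorical cross-entropy;
# Adam is one reasonable optimizer choice. Tracking accuracy lets evaluate() report it.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
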
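Because y_train holds integer digit labels rather than one-hot vectors, a natural loss here is sparse categorical cross-entropy: the mean of -log(p) over the probability the model assigns to each example's true class. A minimal NumPy sketch of that computation, assuming softmax probabilities as input; Keras's own implementation differs in details such as clipping and reduction.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    # labels: integer class ids, shape (batch,); probs: softmax outputs, shape (batch, classes)
    # Select the probability of the true class in each row, clip away exact zeros,
    # then average the negative log-probabilities
    p_true = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])
loss = sparse_categorical_crossentropy(labels, probs)
print(loss)
```

A confident correct prediction (p close to 1) contributes a loss near 0, while a low probability on the true class is penalized heavily.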
Start TensorBoard so you can monitor training progress.

# Define a user-unique directory in DBFS
try:
  username = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().apply('user')
except Exception:
  # Fall back to a random identifier if the user name is unavailable
  username = str(uuid.uuid1()).replace("-", "")
experiment_log_dir = "/dbfs/user/{}/tensorboard_log_dir/".format(username)
%load_ext tensorboard
%tensorboard --logdir $experiment_log_dir
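TensorBoard is pointed at the per-user experiment directory, and each training run writes to its own timestamped subdirectory beneath it, so successive runs show up side by side. The pure-Python sketch below mirrors that layout using the notebook's uuid fallback in place of a real Databricks user name; the helper names are hypothetical and no files are created.

```python
import datetime
import posixpath
import uuid

def fallback_username():
    # Same fallback as the notebook: a uuid1 with the dashes stripped (32 hex characters)
    return str(uuid.uuid1()).replace("-", "")

def experiment_dir(username):
    # Hypothetical helper: per-user TensorBoard root on the /dbfs mount
    return "/dbfs/user/{}/tensorboard_log_dir/".format(username)

def run_dir(exp_dir, when):
    # Hypothetical helper: one timestamped subdirectory per training run
    return exp_dir + when.strftime("%Y%m%d-%H%M%S")

name = fallback_username()
exp = experiment_dir(name)
run_a = run_dir(exp, datetime.datetime(2024, 1, 1, 12, 0, 0))
run_b = run_dir(exp, datetime.datetime(2024, 1, 1, 12, 30, 5))
parent = posixpath.dirname(run_a)

print(run_a)
print(run_b)
```

Because the timestamp format runs from year down to second, the run directory names also sort chronologically.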

Train the model in batches

run_log_dir = experiment_log_dir + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=run_log_dir, histogram_freq=1)
model.fit(x_train, y_train, epochs=5, batch_size=64, callbacks=[tensorboard_callback])
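With batch_size=64, each epoch walks through the 60,000 training images in mini-batches, and the optimizer updates the weights once per batch. The sketch below only does the bookkeeping arithmetic; it omits the per-epoch shuffling that Keras applies by default (shuffle=True), and the helper names are hypothetical.

```python
import math

def steps_per_epoch(num_examples, batch_size):
    # The last batch may be smaller than batch_size, so round up
    return math.ceil(num_examples / batch_size)

def batch_bounds(num_examples, batch_size):
    # (start, end) index pairs for each mini-batch of one epoch
    return [(i, min(i + batch_size, num_examples))
            for i in range(0, num_examples, batch_size)]

steps = steps_per_epoch(60000, 64)   # the MNIST training set has 60,000 images
total_updates = steps * 5            # epochs=5
bounds = batch_bounds(60000, 64)

print(steps, total_updates, bounds[-1])
```

Here 60,000 / 64 = 937.5, so each epoch has 937 full batches plus one final batch of 32 images.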

Evaluate the trained model on the test set. The final loss and accuracy are printed at the bottom of the cell output. You can compare the accuracy with the accuracy reported by the other frameworks!

model.evaluate(x_test, y_test, verbose=2)
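The accuracy metric counts a test image as correct when the class with the highest predicted probability equals its label. A minimal NumPy sketch of that computation on made-up probabilities (not model output):

```python
import numpy as np

def accuracy(probs, labels):
    # Predicted class = index of the largest probability in each row
    preds = np.argmax(probs, axis=1)
    # Fraction of rows where the prediction matches the integer label
    return float(np.mean(preds == labels))

probs = np.array([[0.1, 0.9],
                  [0.8, 0.2],
                  [0.3, 0.7],
                  [0.6, 0.4]])
labels = np.array([1, 0, 0, 0])
acc = accuracy(probs, labels)
print(acc)
```

Three of the four made-up predictions match their labels, so the sketch reports 0.75.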

TensorBoard stays active after training finishes, so you can view a summary of the process. Detach your notebook to stop TensorBoard.

(Optional) Remove your log files from DBFS.

dbutils.fs.rm(experiment_log_dir.replace("/dbfs",""), recurse=True)
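The log directory is written through the /dbfs local mount, while dbutils.fs.rm takes a DBFS path without that prefix. The str.replace call removes every occurrence of "/dbfs", not just the leading one, so a path that happens to contain "/dbfs" further along would be altered too. A minimal sketch of a stricter conversion, using a hypothetical helper name and pure string handling (no Databricks calls):

```python
def dbfs_api_path(fuse_path):
    # Hypothetical helper: strip only the leading /dbfs mount prefix
    prefix = "/dbfs/"
    if not fuse_path.startswith(prefix):
        raise ValueError("not a /dbfs/ path: " + fuse_path)
    return "/" + fuse_path[len(prefix):]

ok = dbfs_api_path("/dbfs/user/alice/tensorboard_log_dir/")

# A path that contains "/dbfs" a second time, later in the string
tricky = "/dbfs/user/alice/dbfs_logs/"
strict = dbfs_api_path(tricky)
naive = tricky.replace("/dbfs", "")

print(ok)
print(strict, naive)
```

The naive replace turns the tricky path into "/user/alice_logs/", while the prefix-only version keeps "/user/alice/dbfs_logs/".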