Load TFRecord files for distributed DL (Python)

Preparing MNIST data for Distributed DL

This notebook uses MNIST as an example to show how to load TFRecord files for distributed DL.

Before running this notebook, you must:

  1. Save the MNIST training data as TFRecord files at tfrecord_location (the path defined in Step 0 below).

import os
import tensorflow as tf

Step 0: Define the TFRecord save location.

We recommend saving training data under dbfs:/ml, which maps to file:/dbfs/ml on the driver and worker nodes. dbfs:/ml is a special folder that provides high-performance I/O for deep learning workloads.

tfrecord_location = '/dbfs/ml/MNISTDemo/mnistData/'
name = "train.tfrecords"
filename = os.path.join(tfrecord_location, name)

Step 1: Create a TFRecordDataset as an input pipeline.

dataset = tf.data.TFRecordDataset(filename)

Step 2: Define a decoder to read and parse data.

def decode(serialized_example):
  """
  Parses an image and label from the given `serialized_example`.
  It is used as the map function for `dataset.map`.
  """
  IMAGE_SIZE = 28
  IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

  # 1. Define a parser.
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

  # 2. Convert the data: decode the raw bytes and cast the label.
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  label = tf.cast(features['label'], tf.int32)
  # 3. Set the static shape: each image is a flat vector of 784 pixels.
  image.set_shape((IMAGE_PIXELS,))
  return image, label

Parse the records into tensors with map, which takes a Python function and applies it to every sample in the dataset.

dataset = dataset.map(decode)
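Conceptually, decode turns each serialized record into an (image, label) pair, and map applies it to every record. The same idea can be sketched in plain Python; note the byte layout below is made up for illustration, since real TFRecord records hold serialized tf.train.Example protos:

```python
import struct

IMAGE_PIXELS = 28 * 28

def decode_py(serialized_example):
    """Toy analogue of `decode`: split one record into pixels and a label.

    Assumes an illustrative layout: 784 uint8 pixel bytes followed by a
    single little-endian int64 label. This only mirrors the shape of the
    map function, not the actual TFRecord/proto format.
    """
    pixels = list(serialized_example[:IMAGE_PIXELS])
    (label,) = struct.unpack(
        '<q', serialized_example[IMAGE_PIXELS:IMAGE_PIXELS + 8])
    return pixels, label

# Apply the decoder to every sample, like dataset.map(decode):
records = [bytes(range(256)) * 3 + bytes(16) + struct.pack('<q', 7)]
decoded = [decode_py(r) for r in records]
```

Each decoded element is a (pixels, label) pair, just as each element of the mapped dataset is an (image, label) tensor pair.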

Step 3: Preprocess the data.

def normalize(image, label):
  """Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
  return image, label
dataset = dataset.map(normalize)
batch_size = 1000
dataset = dataset.shuffle(1000 + 3 * batch_size)
dataset = dataset.repeat(2)
dataset = dataset.batch(batch_size)
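The shuffle/repeat/batch chain above can be sketched as plain Python generators to show what each stage does. This is a simplified illustration, not the tf.data implementation (for example, tf.data reshuffles each epoch by default, while this sketch repeats one shuffled order):

```python
import random

def shuffle_stream(samples, buffer_size, rng):
    """Buffered shuffle like Dataset.shuffle: keep up to `buffer_size`
    elements in a buffer and emit a random one as each new element
    arrives. Larger buffers give a more uniform shuffle."""
    buf = []
    for s in samples:
        buf.append(s)
        if len(buf) >= buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    while buf:  # drain the buffer at the end of the stream
        yield buf.pop(rng.randrange(len(buf)))

def repeat_stream(samples, count):
    """Like Dataset.repeat(count): iterate over the data `count` times."""
    data = list(samples)
    for _ in range(count):
        yield from data

def batch_stream(samples, batch_size):
    """Like Dataset.batch: group consecutive samples into batches;
    the final batch may be smaller than `batch_size`."""
    batch = []
    for s in samples:
        batch.append(s)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Toy pipeline mirroring shuffle -> repeat -> batch on 10 samples:
rng = random.Random(0)
stream = shuffle_stream(range(10), buffer_size=4, rng=rng)
stream = repeat_stream(stream, 2)
stream = batch_stream(stream, batch_size=6)
batches = list(stream)
```

With 10 samples repeated twice and a batch size of 6, this yields three full batches plus a final batch of 2, which is why shuffle and repeat are applied before batch in the pipeline above.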

Step 4: Create an iterator.

iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
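As a rough analogy (with plain Python lists standing in for the batched dataset), a one-shot iterator behaves like a Python iterator: it makes a single pass over the data and then signals exhaustion, which in TF 1.x surfaces as tf.errors.OutOfRangeError from sess.run:

```python
# Stand-in for the batched dataset: three batches of samples.
batches = [[1, 2], [3, 4], [5]]
iterator = iter(batches)              # ~ dataset.make_one_shot_iterator()

consumed = []
while True:
    try:
        consumed.append(next(iterator))   # ~ sess.run(iterator.get_next())
    except StopIteration:                 # ~ tf.errors.OutOfRangeError
        break
```

Training loops over a one-shot iterator typically catch that exhaustion error to detect the end of the (repeated) dataset.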

Use the dataset as an input to train the model. For a full example that uses TFRecord files as input for DL training, see the TensorFlow MNIST Example.

sess = tf.Session()
# Fetch one batch as NumPy arrays (use new names to avoid clobbering the tensors).
images, labels = sess.run([image_batch, label_batch])
print(images.shape)  # (1000, 784)
print(labels.shape)  # (1000,)