Save data from Spark DataFrames to TFRecords and load it using TensorFlow (Python)


Preparing image data for Distributed DL

This notebook uses the flowers dataset from the TensorFlow team as an example to show how to save image data from Spark DataFrames to TFRecords and load it using TensorFlow.

The dataset contains flower photos stored in five subdirectories, one per class. It is hosted under Databricks Datasets at dbfs:/databricks-datasets/flower_photos for easy access.

This notebook loads the flowers Delta table, which contains the flowers dataset preprocessed with the binary file data source.
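
For context, here is a minimal sketch of how the raw flower photos could be read with the binary file data source. This is an illustrative assumption about how the preprocessed table was built; the notebook itself loads the ready-made Delta table in the next step and does not run this cell.

# Sketch (assumption): read the raw flower photos with the binary file data source.
# The preprocessed Delta table used below already contains the result of a step like this.
raw_df = (
  spark.read.format("binaryFile")
  .option("recursiveFileLookup", "true")
  .option("pathGlobFilter", "*.jpg")
  .load("dbfs:/databricks-datasets/flower_photos")
)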

from pyspark.sql.functions import col, pandas_udf
 
import os
import uuid
import tensorflow as tf

Save data from Spark DataFrames to TFRecords

Step 1: Load data using Spark.

df = spark.read.format("delta").load("/databricks-datasets/flowers/delta")
 
labels = df.select(col("label")).distinct().collect()
label_to_idx = {label: index for index, (label, ) in enumerate(sorted(labels))}
 
@pandas_udf("long")
def get_label_idx(labels):
  return labels.map(lambda label: label_to_idx[label])
 
df = df.withColumn("label_index", get_label_idx(col("label"))) \
  .select(col("content"), col("label_index")) \
  .limit(100)
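
Optionally, sanity-check the result of this step. The snippet below is a small illustrative addition, not part of the original notebook.

# Optional check (sketch): the mapping covers the five flower classes, and the
# DataFrame now holds raw JPEG bytes plus an integer label index.
print(label_to_idx)
df.printSchema()  # expect: content (binary), label_index (long)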

Step 2: Save the data to TFRecord files.

name_uuid = str(uuid.uuid4())
path = '/ml/flowersData/df-{}.tfrecord'.format(name_uuid)
df.limit(100).write.format("tfrecords").mode("overwrite").save(path)
display(dbutils.fs.ls(path))
 
path                                                                                  name           size
dbfs:/ml/flowersData/df-5d42cd4d-0fe5-4932-9fca-cd36477bd088.tfrecord/_SUCCESS       _SUCCESS       0
dbfs:/ml/flowersData/df-5d42cd4d-0fe5-4932-9fca-cd36477bd088.tfrecord/part-r-00000   part-r-00000   12490605

Load TFRecords using TensorFlow

Step 1: Create a TFRecordDataset as an input pipeline.

filenames = [("/dbfs" + path + "/" + name) for name in os.listdir("/dbfs" + path) if name.startswith("part")]
dataset = tf.data.TFRecordDataset(filenames)

Step 2: Define a decoder to read, parse, and normalize the data.

Parse the records into tensors with map, which takes a Python function and applies it to every element of the dataset.

def decode_and_normalize(serialized_example):
  """
  Decode and normalize an image and label from the given `serialized_example`.
  It is used as a map function for `dataset.map`
  """
  IMAGE_SIZE = 224
  
  # 1. define a parser
  feature_dataset = tf.io.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'content': tf.io.FixedLenFeature([], tf.string),
          'label_index': tf.io.FixedLenFeature([], tf.int64),
      })
  # 2. decode the data
  image = tf.io.decode_jpeg(feature_dataset['content'])
  label = tf.cast(feature_dataset['label_index'], tf.int32)
  # 3. resize
  image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
  # 4. normalize the data
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
  return image, label
 
parsed_dataset = dataset.map(decode_and_normalize)
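
To confirm the decoder works end to end, you can pull a single example. This check is an illustrative addition, not part of the original notebook.

# Optional check (sketch): take one parsed example and confirm its shape and label.
for image, label in parsed_dataset.take(1):
  print(image.shape, label.numpy())  # expect (224, 224, 3) and a label index in 0-4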

Use the dataset as an input to train the model.

For a full example that uses TFRecord files as input for DL training, see tf.data: Build TensorFlow input pipelines.

batch_size = 4
parsed_dataset = parsed_dataset.shuffle(40)
parsed_dataset = parsed_dataset.repeat(2)
parsed_dataset = parsed_dataset.batch(batch_size)
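
As an illustration of that last step, here is a minimal training sketch. The small Keras model, its layer sizes, and the single epoch are assumptions made for demonstration; they are not part of the original notebook and are not tuned for the flowers data.

# Sketch (assumption): a tiny Keras CNN trained on the batched dataset.
model = tf.keras.Sequential([
  tf.keras.layers.Input(shape=(224, 224, 3)),        # matches IMAGE_SIZE in the decoder
  tf.keras.layers.Conv2D(32, 3, activation="relu"),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(5, activation="softmax"),    # one output per flower class
])
model.compile(
  optimizer="adam",
  loss="sparse_categorical_crossentropy",
  metrics=["accuracy"],
)
model.fit(parsed_dataset, epochs=1)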

Finally, remove the temporary TFRecord files from DBFS.

dbutils.fs.rm(path, True)