spark-tensorflow-connector (Scala)

Spark-TensorFlow-Connector Example

This notebook (adapted from the spark-tensorflow-connector usage examples) demonstrates exporting Spark DataFrames to TFRecords and loading the exported TFRecords back into DataFrames.
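Note: the "tfrecords" data source used throughout requires the spark-tensorflow-connector library to be attached to the cluster (for example, as a Maven library under the org.tensorflow:spark-tensorflow-connector coordinate matching your Scala version; treat the exact coordinate and version as assumptions for your environment). Without it, the format("tfrecords") calls below will fail to resolve.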

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

Create DataFrame

Construct a DataFrame with columns of various types (int, long, float, double, array, string)

// Declare DataFrame data
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))
)

// DataFrame schema
val schema = StructType(List(StructField("id", IntegerType), 
                             StructField("IntegerTypeLabel", IntegerType),
                             StructField("LongTypeLabel", LongType),
                             StructField("FloatTypeLabel", FloatType),
                             StructField("DoubleTypeLabel", DoubleType),
                             StructField("VectorLabel", ArrayType(DoubleType, true)),
                             StructField("name", StringType)))
// Create DataFrame
val rdd = spark.sparkContext.parallelize(testRows)
val df: DataFrame = spark.createDataFrame(rdd, schema)
testRows: Array[org.apache.spark.sql.Row] = Array([11,1,23,10.0,14.0,List(1.0, 2.0),r1], [21,2,24,12.0,15.0,List(2.0, 2.0),r2])
schema: org.apache.spark.sql.types.StructType = StructType(StructField(id,IntegerType,true), StructField(IntegerTypeLabel,IntegerType,true), StructField(LongTypeLabel,LongType,true), StructField(FloatTypeLabel,FloatType,true), StructField(DoubleTypeLabel,DoubleType,true), StructField(VectorLabel,ArrayType(DoubleType,true),true), StructField(name,StringType,true))
rdd: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = ParallelCollectionRDD[105] at parallelize at command-2002623810216418:16
df: org.apache.spark.sql.DataFrame = [id: int, IntegerTypeLabel: int ... 5 more fields]
df.show()
+---+----------------+-------------+--------------+---------------+-----------+----+
| id|IntegerTypeLabel|LongTypeLabel|FloatTypeLabel|DoubleTypeLabel|VectorLabel|name|
+---+----------------+-------------+--------------+---------------+-----------+----+
| 11|               1|           23|          10.0|           14.0| [1.0, 2.0]|  r1|
| 21|               2|           24|          12.0|           15.0| [2.0, 2.0]|  r2|
+---+----------------+-------------+--------------+---------------+-----------+----+

Export DataFrame to TFRecords

WARNING: The command below will overwrite existing data

val path = "/tmp/dl/spark-tf-connector/test-output.tfrecord"
df.write.format("tfrecords").option("recordType", "Example").mode("overwrite").save(path)
path: String = /tmp/dl/spark-tf-connector/test-output.tfrecord
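The writer above emits plain "Example" records. The connector also supports compressed output via a Hadoop codec; a minimal sketch follows (the "codec" option name follows the upstream README, and the output path is hypothetical, so verify both against the connector version in use):

// Sketch: gzip-compressed TFRecord export.
// The "codec" option is per the upstream README; the path is hypothetical.
df.write.format("tfrecords")
  .option("recordType", "Example")
  .option("codec", "org.apache.hadoop.io.compress.GzipCodec")
  .mode("overwrite")
  .save("/tmp/dl/spark-tf-connector/test-output-gz.tfrecord")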

Read exported TFRecords back into a DataFrame

Note that the imported data matches the original (compare df.show() with importedDf1.show()). With an inferred schema the column order may differ and doubles come back as floats, since TFRecord Example features store floating-point values as float lists; supplying the original schema (importedDf2) restores an exact match.

// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()

// Read TFRecords into a DataFrame using a custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
+-----------+---------------+----+-------------+----------------+--------------+---+
|VectorLabel|DoubleTypeLabel|name|LongTypeLabel|IntegerTypeLabel|FloatTypeLabel| id|
+-----------+---------------+----+-------------+----------------+--------------+---+
| [1.0, 2.0]|           14.0|  r1|           23|               1|          10.0| 11|
| [2.0, 2.0]|           15.0|  r2|           24|               2|          12.0| 21|
+-----------+---------------+----+-------------+----------------+--------------+---+

+---+----------------+-------------+--------------+---------------+-----------+----+
| id|IntegerTypeLabel|LongTypeLabel|FloatTypeLabel|DoubleTypeLabel|VectorLabel|name|
+---+----------------+-------------+--------------+---------------+-----------+----+
| 11|               1|           23|          10.0|           14.0| [1.0, 2.0]|  r1|
| 21|               2|           24|          12.0|           15.0| [2.0, 2.0]|  r2|
+---+----------------+-------------+--------------+---------------+-----------+----+

importedDf1: org.apache.spark.sql.DataFrame = [VectorLabel: array<float>, DoubleTypeLabel: float ... 5 more fields]
importedDf2: org.apache.spark.sql.DataFrame = [id: int, IntegerTypeLabel: int ... 5 more fields]
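Beyond eyeballing the tables, the round trip can be verified programmatically. A minimal sketch, assuming the tiny data sizes here make collect() safe:

// Sketch: realign the imported columns to the original order, then compare
// the collected rows as sets (row order after a read is not guaranteed).
import org.apache.spark.sql.functions.col
val realigned = importedDf2.select(df.columns.map(col): _*)
assert(realigned.collect().toSet == df.collect().toSet)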

Loading an existing TFRecord dataset into Spark

The example below loads a sample of the YouTube-8M video-level dataset into a DataFrame. First, download one training shard to DBFS:

%sh
# Ensure the target directory exists, then fetch one shard of the dataset
mkdir -p /dbfs/tmp/dl/spark-tf-connector
curl -s http://us.data.yt8m.org/2/video/train/trainIc.tfrecord > /dbfs/tmp/dl/spark-tf-connector/video_level-train-0.tfrecord
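As a quick sanity check, list the target directory to confirm the file landed:

%fs ls /tmp/dl/spark-tf-connector/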

Declare schema and import data into a DataFrame

// Import the video-level Example dataset into a DataFrame
val videoSchema = StructType(List(StructField("id", StringType),
                             StructField("labels", ArrayType(IntegerType, true)),
                             StructField("mean_rgb", ArrayType(FloatType, true)),
                             StructField("mean_audio", ArrayType(FloatType, true))))
val videoDf: DataFrame = spark.read.format("tfrecords").schema(videoSchema).option("recordType", "Example")
  .load("dbfs:/tmp/dl/spark-tf-connector/video_level-train-0.tfrecord")
videoDf.show(5)
+----+--------------------+--------------------+--------------------+
|  id|              labels|            mean_rgb|          mean_audio|
+----+--------------------+--------------------+--------------------+
|rjIc|[2, 26, 45, 130, ...|[0.80833536, 0.73...|[-0.47192606, -0....|
|SUIc|              [1343]|[-0.65343785, 0.7...|[0.41343987, 1.44...|
|MjIc|  [2, 45, 212, 1745]|[-0.17644894, 1.0...|[-1.3482137, 0.72...|
|rzIc|[11, 20, 22, 29, ...|[0.06379097, 0.74...|[-0.35289493, -0....|
|zXIc|         [0, 1, 828]|[0.21688479, -1.2...|[-0.90491796, -0....|
+----+--------------------+--------------------+--------------------+
only showing top 5 rows

videoSchema: org.apache.spark.sql.types.StructType = StructType(StructField(id,StringType,true), StructField(labels,ArrayType(IntegerType,true),true), StructField(mean_rgb,ArrayType(FloatType,true),true), StructField(mean_audio,ArrayType(FloatType,true),true))
videoDf: org.apache.spark.sql.DataFrame = [id: string, labels: array<int> ... 2 more fields]
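With the shard loaded, standard DataFrame operations apply. As a purely illustrative sketch, the most frequent labels in this sample can be counted by exploding the labels array:

// Sketch: label frequencies via explode on the `labels` array column.
import org.apache.spark.sql.functions.{col, desc, explode}
videoDf.select(explode(col("labels")).as("label"))
  .groupBy("label")
  .count()
  .orderBy(desc("count"))
  .show(5)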

Export data to TFRecords and import it back into a DataFrame

Note that the imported DataFrame (importedDf1) matches the original (videoDf).

// Write DataFrame to a tfrecords file
// WARNING: This command will overwrite existing data
videoDf.write.format("tfrecords").option("recordType", "Example").mode("overwrite")
  .save("dbfs:/tmp/dl/spark-tf-connector/youtube-8m-video.tfrecords")
// Import data back into a DataFrame, verify that it matches the original
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(videoSchema)
  .load("dbfs:/tmp/dl/spark-tf-connector/youtube-8m-video.tfrecords")
importedDf1.show(5)
+----+--------------------+--------------------+--------------------+
|  id|              labels|            mean_rgb|          mean_audio|
+----+--------------------+--------------------+--------------------+
|rjIc|[2, 26, 45, 130, ...|[0.80833536, 0.73...|[-0.47192606, -0....|
|SUIc|              [1343]|[-0.65343785, 0.7...|[0.41343987, 1.44...|
|MjIc|  [2, 45, 212, 1745]|[-0.17644894, 1.0...|[-1.3482137, 0.72...|
|rzIc|[11, 20, 22, 29, ...|[0.06379097, 0.74...|[-0.35289493, -0....|
|zXIc|         [0, 1, 828]|[0.21688479, -1.2...|[-0.90491796, -0....|
+----+--------------------+--------------------+--------------------+
only showing top 5 rows

importedDf1: org.apache.spark.sql.DataFrame = [id: string, labels: array<int> ... 2 more fields]
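A lightweight programmatic check of the round trip (a sketch that compares only the id column, since float arrays are awkward to compare exactly):

// Sketch: sorted ids should be identical before and after the round trip.
val originalIds = videoDf.select("id").collect().map(_.getString(0)).sorted
val roundTripIds = importedDf1.select("id").collect().map(_.getString(0)).sorted
assert(originalIds.sameElements(roundTripIds))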

Remove downloaded data files

%fs rm -r /tmp/dl/spark-tf-connector/
res10: Boolean = true
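The same cleanup can also be done from a Scala cell via Databricks utilities, equivalent to the %fs magic above:

// Recursively remove the working directory (Databricks only)
dbutils.fs.rm("/tmp/dl/spark-tf-connector/", true)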