Spark-Tensorflow-Connector Example

This notebook (adapted from the spark-tensorflow-connector usage examples) demonstrates exporting Spark DataFrames to TFRecords and loading the exported TFRecords back into DataFrames.
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
// Declare DataFrame data
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))
)

// DataFrame schema
val schema = StructType(List(
  StructField("id", IntegerType),
  StructField("IntegerTypeLabel", IntegerType),
  StructField("LongTypeLabel", LongType),
  StructField("FloatTypeLabel", FloatType),
  StructField("DoubleTypeLabel", DoubleType),
  StructField("VectorLabel", ArrayType(DoubleType, true)),
  StructField("name", StringType)
))

// Create DataFrame
val rdd = spark.sparkContext.parallelize(testRows)
val df: DataFrame = spark.createDataFrame(rdd, schema)
testRows: Array[org.apache.spark.sql.Row] = Array([11,1,23,10.0,14.0,List(1.0, 2.0),r1], [21,2,24,12.0,15.0,List(2.0, 2.0),r2])
schema: org.apache.spark.sql.types.StructType = StructType(StructField(id,IntegerType,true), StructField(IntegerTypeLabel,IntegerType,true), StructField(LongTypeLabel,LongType,true), StructField(FloatTypeLabel,FloatType,true), StructField(DoubleTypeLabel,DoubleType,true), StructField(VectorLabel,ArrayType(DoubleType,true),true), StructField(name,StringType,true))
rdd: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = ParallelCollectionRDD[105] at parallelize at command-2002623810216418:16
df: org.apache.spark.sql.DataFrame = [id: int, IntegerTypeLabel: int ... 5 more fields]
df.show()
+---+----------------+-------------+--------------+---------------+-----------+----+
| id|IntegerTypeLabel|LongTypeLabel|FloatTypeLabel|DoubleTypeLabel|VectorLabel|name|
+---+----------------+-------------+--------------+---------------+-----------+----+
| 11| 1| 23| 10.0| 14.0| [1.0, 2.0]| r1|
| 21| 2| 24| 12.0| 15.0| [2.0, 2.0]| r2|
+---+----------------+-------------+--------------+---------------+-----------+----+
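The next cell loads from a path variable whose defining cell (the export step) is not captured in this export. A minimal sketch of that step, assuming a DBFS scratch location (the exact path value is an assumption, not from the original notebook):

// Export the DataFrame to TFRecords in Example format.
// NOTE: this path is an assumed scratch location.
val path = "dbfs:/tmp/dl/spark-tf-connector/test-output.tfrecord"
df.write.format("tfrecords").option("recordType", "Example").save(path)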
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()

// Read TFRecords into DataFrame using custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
+-----------+---------------+----+-------------+----------------+--------------+---+
|VectorLabel|DoubleTypeLabel|name|LongTypeLabel|IntegerTypeLabel|FloatTypeLabel| id|
+-----------+---------------+----+-------------+----------------+--------------+---+
| [1.0, 2.0]| 14.0| r1| 23| 1| 10.0| 11|
| [2.0, 2.0]| 15.0| r2| 24| 2| 12.0| 21|
+-----------+---------------+----+-------------+----------------+--------------+---+
+---+----------------+-------------+--------------+---------------+-----------+----+
| id|IntegerTypeLabel|LongTypeLabel|FloatTypeLabel|DoubleTypeLabel|VectorLabel|name|
+---+----------------+-------------+--------------+---------------+-----------+----+
| 11| 1| 23| 10.0| 14.0| [1.0, 2.0]| r1|
| 21| 2| 24| 12.0| 15.0| [2.0, 2.0]| r2|
+---+----------------+-------------+--------------+---------------+-----------+----+
importedDf1: org.apache.spark.sql.DataFrame = [VectorLabel: array<float>, DoubleTypeLabel: float ... 5 more fields]
importedDf2: org.apache.spark.sql.DataFrame = [id: int, IntegerTypeLabel: int ... 5 more fields]
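Note the types in importedDf1: TFRecord Example features store floating-point values as 32-bit float lists, so without a custom schema DoubleTypeLabel and VectorLabel are inferred as float and array&lt;float&gt;; supplying the original schema, as for importedDf2, restores the declared types. A hypothetical roundtrip check (not in the original notebook; exact equality holds here because these particular values survive the double-to-float conversion):

// Verify the custom-schema import matches the original DataFrame
assert(importedDf2.count() == df.count())
assert(importedDf2.except(df).isEmpty && df.except(importedDf2).isEmpty)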
%sh
curl -s http://us.data.yt8m.org/2/video/train/trainIc.tfrecord > /dbfs/tmp/dl/spark-tf-connector/video_level-train-0.tfrecord
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 4425k  100 4425k    0     0  7242k      0 --:--:-- --:--:-- --:--:-- 7231k
// Import the video-level Example dataset into a DataFrame
val videoSchema = StructType(List(
  StructField("id", StringType),
  StructField("labels", ArrayType(IntegerType, true)),
  StructField("mean_rgb", ArrayType(FloatType, true)),
  StructField("mean_audio", ArrayType(FloatType, true))
))
val videoDf: DataFrame = spark.read.format("tfrecords")
  .schema(videoSchema)
  .option("recordType", "Example")
  .load("dbfs:/tmp/dl/spark-tf-connector/video_level-train-0.tfrecord")
videoDf.show(5)
+----+--------------------+--------------------+--------------------+
| id| labels| mean_rgb| mean_audio|
+----+--------------------+--------------------+--------------------+
|rjIc|[2, 26, 45, 130, ...|[0.80833536, 0.73...|[-0.47192606, -0....|
|SUIc| [1343]|[-0.65343785, 0.7...|[0.41343987, 1.44...|
|MjIc| [2, 45, 212, 1745]|[-0.17644894, 1.0...|[-1.3482137, 0.72...|
|rzIc|[11, 20, 22, 29, ...|[0.06379097, 0.74...|[-0.35289493, -0....|
|zXIc| [0, 1, 828]|[0.21688479, -1.2...|[-0.90491796, -0....|
+----+--------------------+--------------------+--------------------+
only showing top 5 rows
videoSchema: org.apache.spark.sql.types.StructType = StructType(StructField(id,StringType,true), StructField(labels,ArrayType(IntegerType,true),true), StructField(mean_rgb,ArrayType(FloatType,true),true), StructField(mean_audio,ArrayType(FloatType,true),true))
videoDf: org.apache.spark.sql.DataFrame = [id: string, labels: array<int> ... 2 more fields]
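The next cell reads youtube-8m-video.tfrecords back in; the cell that exported it is not captured here. A minimal sketch of that export step (assumed, mirroring the path read below):

// Export the video-level DataFrame to TFRecords in Example format
videoDf.write.format("tfrecords").option("recordType", "Example")
  .save("dbfs:/tmp/dl/spark-tf-connector/youtube-8m-video.tfrecords")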
// Import the data back into a DataFrame and verify that it matches the original
val importedDf1: DataFrame = spark.read.format("tfrecords")
  .option("recordType", "Example")
  .schema(videoSchema)
  .load("dbfs:/tmp/dl/spark-tf-connector/youtube-8m-video.tfrecords")
importedDf1.show(5)
+----+--------------------+--------------------+--------------------+
| id| labels| mean_rgb| mean_audio|
+----+--------------------+--------------------+--------------------+
|rjIc|[2, 26, 45, 130, ...|[0.80833536, 0.73...|[-0.47192606, -0....|
|SUIc| [1343]|[-0.65343785, 0.7...|[0.41343987, 1.44...|
|MjIc| [2, 45, 212, 1745]|[-0.17644894, 1.0...|[-1.3482137, 0.72...|
|rzIc|[11, 20, 22, 29, ...|[0.06379097, 0.74...|[-0.35289493, -0....|
|zXIc| [0, 1, 828]|[0.21688479, -1.2...|[-0.90491796, -0....|
+----+--------------------+--------------------+--------------------+
only showing top 5 rows
importedDf1: org.apache.spark.sql.DataFrame = [id: string, labels: array<int> ... 2 more fields]
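As a final sanity check, the re-imported DataFrame can be compared with the original videoDf; a hypothetical verification, not part of the original notebook:

// Both set differences should be empty if the roundtrip preserved the data exactly
assert(importedDf1.except(videoDf).isEmpty && videoDf.except(importedDf1).isEmpty)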