Introduction to Datasets

The Dataset API, added in Apache Spark 1.6, combines the benefits of RDDs (strong typing and the ability to use powerful lambda functions) with the benefits of Spark SQL’s optimized execution engine, which builds on Project Tungsten. A user can define a Dataset of JVM objects and then manipulate it using functional transformations (map, flatMap, filter, etc.), much like an RDD. The benefit is that, unlike with RDDs, these transformations are applied to a structured, strongly typed distributed collection, which allows Spark to use Spark SQL’s execution engine for optimization.

Verify that the notebook is running Apache Spark 1.6 or later, which the Dataset API requires.

val Array(major, minor) = sc.version.split("\\.").take(2).map(_.toInt)  // e.g. "1.6.0" => (1, 6)
require(major > 1 || (major == 1 && minor >= 6), "Spark 1.6.0+ required to run this notebook.")

Creating Datasets

You can simply call .toDS() on a sequence to convert the sequence to a Dataset.

val dataset = Seq(1, 2, 3).toDS()

If you have a sequence of case class instances, calling .toDS() produces a Dataset whose columns correspond to the fields of the case class.

case class Person(name: String, age: Int)

val personDS = Seq(Person("Max", 33), Person("Adam", 32), Person("Muller", 62)).toDS()
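
The introduction describes typed functional transformations on a Dataset; the following is a minimal sketch of what that can look like on a Dataset of case class instances. It assumes a notebook session in which sqlContext is predefined (as sc is above), and Visitor is a hypothetical case class used only for this illustration.

```scala
// Assumption: a notebook session in which `sqlContext` is predefined (as `sc` is above).
import sqlContext.implicits._

// Hypothetical case class, used only for this sketch.
case class Visitor(name: String, age: Int)

val visitors = Seq(Visitor("Max", 33), Visitor("Adam", 32), Visitor("Muller", 62)).toDS()

// Typed transformations: the lambdas receive Visitor objects, not untyped rows.
val olderNames = visitors
  .filter(_.age > 32)   // keep visitors older than 32
  .map(_.name)          // Dataset[Visitor] => Dataset[String]

olderNames.collect()    // brings the matching names back to the driver
```

A misspelled field such as _.agee would be rejected by the compiler here, which is the strong typing the introduction refers to.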

Creating Datasets from an RDD

You can call rdd.toDS() to convert an RDD into a Dataset.

val rdd = sc.parallelize(Seq((1, "Spark"), (2, "Databricks")))
val integerDS = rdd.toDS()

Creating Datasets from a DataFrame

You can call df.as[SomeCaseClass] to convert the DataFrame to a Dataset.

case class Company(name: String, foundingYear: Int, numEmployees: Int)
val inputSeq = Seq(Company("ABC", 1998, 310), Company("XYZ", 1983, 904), Company("NOP", 2005, 83))
val df = sc.parallelize(inputSeq).toDF()

val companyDS = df.as[Company]

You can also convert a DataFrame to a Dataset of tuples, without defining a case class.

val rdd = sc.parallelize(Seq((1, "Spark"), (2, "Databricks"), (3, "Notebook")))
val df = rdd.toDF("Id", "Name")

val dataset = df.as[(Int, String)]

Working with Datasets

Word Count Example

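val wordsDataset = sc.parallelize(Seq("Spark I am your father", "May the spark be with you", "Spark I am your father")).toDS()
val groupedDataset = wordsDataset.flatMap(_.toLowerCase.split(" "))  // Lowercase each line, then split it on spaces
                                 .filter(_ != "")                    // Drop empty words
                                 .groupBy($"value")                  // Assumed step (missing from the source): group by word
val countsDataset = groupedDataset.count()                           // Number of occurrences of each word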

Joining in Datasets

The following example demonstrates how to union multiple Datasets, perform an inner join on a condition, group by specific columns, and apply a custom aggregation (average) to the grouped Dataset.

The example uses only the Dataset API, to demonstrate the operations it makes available. In practice, using DataFrames for the aggregation would be simpler and faster than a custom aggregation with mapGroups. The next section covers converting Datasets to DataFrames and using the DataFrame API for aggregations.

case class Employee(name: String, age: Int, departmentId: Int, salary: Double)
case class Department(id: Int, name: String)

case class Record(name: String, age: Int, salary: Double, departmentId: Int, departmentName: String)
case class ResultSet(departmentId: Int, departmentName: String, avgSalary: Double)

val employeeDataSet1 = sc.parallelize(Seq(Employee("Max", 22, 1, 100000.0), Employee("Adam", 33, 2, 93000.0), Employee("Eve", 35, 2, 89999.0), Employee("Muller", 39, 3, 120000.0))).toDS()
val employeeDataSet2 = sc.parallelize(Seq(Employee("John", 26, 1, 990000.0), Employee("Joe", 38, 3, 115000.0))).toDS()
val departmentDataSet = sc.parallelize(Seq(Department(1, "Engineering"), Department(2, "Marketing"), Department(3, "Sales"))).toDS()

val employeeDataset = employeeDataSet1.union(employeeDataSet2)

def averageSalary(key: (Int, String), iterator: Iterator[Record]): ResultSet = {
   // Accumulate (sum of salaries, number of records) in a single pass over the group
   val (total, count) = iterator.foldLeft((0.0, 0.0)) {
       case ((total, count), x) => (total + x.salary, count + 1)
   }
   ResultSet(key._1, key._2, total/count)
}
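
To make the fold in averageSalary concrete, here is a small sketch that runs the same logic on a plain Scala Iterator, with no Spark involved. The wrapper object AverageSalarySketch is hypothetical and holds local copies of the case classes so that the block runs on its own; the two records are the Marketing employees from the sample data above.

```scala
// Hypothetical wrapper object so the sketch does not clash with the notebook's definitions.
object AverageSalarySketch {
  // Local copies of the notebook's case classes.
  case class Record(name: String, age: Int, salary: Double, departmentId: Int, departmentName: String)
  case class ResultSet(departmentId: Int, departmentName: String, avgSalary: Double)

  // Same idea as averageSalary: accumulate (sum of salaries, record count) in one pass.
  def averageSalary(key: (Int, String), iterator: Iterator[Record]): ResultSet = {
    val (total, count) = iterator.foldLeft((0.0, 0)) {
      case ((sum, n), record) => (sum + record.salary, n + 1)
    }
    ResultSet(key._1, key._2, total / count)
  }

  // The two Marketing employees from the sample data.
  val marketing = Iterator(
    Record("Adam", 33, 93000.0, 2, "Marketing"),
    Record("Eve", 35, 89999.0, 2, "Marketing"))

  // (93000.0 + 89999.0) / 2 = 91499.5
  val result = averageSalary((2, "Marketing"), marketing)
}
```

An empty iterator would yield NaN for the average, so the sketch assumes every group contains at least one record.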

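val averageSalaryDataset = employeeDataset.joinWith(departmentDataSet, $"departmentId" === $"id", "inner")
                                          .map(record => Record(record._1.name, record._1.age, record._1.salary, record._1.departmentId, record._2.name))
                                          .filter(record => record.age > 25)
                                          .groupBy($"departmentId", $"departmentName")
                                          // Assumed completion (the source is cut off here), using Spark 1.6's GroupedDataset API:
                                          .keyAs[(Int, String)]
                                          .mapGroups(averageSalary _)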
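
The note above says that DataFrames make this kind of aggregation simpler than mapGroups. As a rough sketch of that alternative, the same per-department average could be written with the DataFrame API as follows. It assumes employeeDataset and departmentDataSet as defined above and a notebook session with sqlContext predefined; the names employeeDF, departmentDF, and averageSalaryDF are introduced here only for illustration.

```scala
// Assumption: a notebook session in which `sqlContext` is predefined (as `sc` is above).
import sqlContext.implicits._
import org.apache.spark.sql.functions.avg

// Convert both Datasets to DataFrames. Both have a "name" column, so rename
// the department's to avoid ambiguity after the join.
val employeeDF = employeeDataset.toDF()
val departmentDF = departmentDataSet.toDF().withColumnRenamed("name", "departmentName")

val averageSalaryDF = employeeDF
  .join(departmentDF, employeeDF("departmentId") === departmentDF("id"), "inner")
  .filter($"age" > 25)                          // same age filter as the Dataset version
  .groupBy($"departmentId", $"departmentName")  // one group per department
  .agg(avg($"salary") as "avgSalary")           // built-in average instead of a hand-written fold
```

The built-in avg replaces both the Record case class and the averageSalary function, which is the simplification the note describes.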

Converting Datasets to DataFrames

The two examples above use only the Dataset API. You can also easily move from Datasets to DataFrames and leverage the DataFrame API. The example below implements word count using both the Dataset and DataFrame APIs.

import org.apache.spark.sql.functions._

val wordsDataset = sc.parallelize(Seq("Spark I am your father", "May the spark be with you", "Spark I am your father")).toDS()
val result = wordsDataset
               .flatMap(_.split(" "))               // Split on whitespace
               .filter(_ != "")                     // Filter empty words
               .toDF()                              // Convert to DataFrame to perform aggregation / sorting
               .groupBy($"value")                   // Group by word
               .agg(count("*") as "numOccurrences") // Count the number of occurrences of each word
               .orderBy($"numOccurrences".desc)     // Show most common words first