Introduction to DataFrames - Scala

This notebook demonstrates a number of common Spark DataFrame functions using Scala.

Create DataFrames

// Create the case classes for our domain
case class Department(id: String, name: String)
case class Employee(firstName: String, lastName: String, email: String, salary: Int)
case class DepartmentWithEmployees(department: Department, employees: Seq[Employee])

// Create the Departments
val department1 = new Department("123456", "Computer Science")
val department2 = new Department("789012", "Mechanical Engineering")
val department3 = new Department("345678", "Theater and Drama")
val department4 = new Department("901234", "Indoor Recreation")

// Create the Employees
val employee1 = new Employee("michael", "armbrust", "no-reply@berkeley.edu", 100000)
val employee2 = new Employee("xiangrui", "meng", "no-reply@stanford.edu", 120000)
val employee3 = new Employee("matei", null, "no-reply@waterloo.edu", 140000)
val employee4 = new Employee(null, "wendell", "no-reply@princeton.edu", 160000)

// Create the DepartmentWithEmployees instances from Departments and Employees
val departmentWithEmployees1 = new DepartmentWithEmployees(department1, Seq(employee1, employee2))
val departmentWithEmployees2 = new DepartmentWithEmployees(department2, Seq(employee3, employee4))
val departmentWithEmployees3 = new DepartmentWithEmployees(department3, Seq(employee1, employee4))
val departmentWithEmployees4 = new DepartmentWithEmployees(department4, Seq(employee2, employee3))

Create the first DataFrame from a List of the Case Classes.

val departmentsWithEmployeesSeq1 = Seq(departmentWithEmployees1, departmentWithEmployees2)
val df1 = departmentsWithEmployeesSeq1.toDF()
display(df1)

Create a second DataFrame from a List of Case Classes.

val departmentsWithEmployeesSeq2 = Seq(departmentWithEmployees3, departmentWithEmployees4)
val df2 = departmentsWithEmployeesSeq2.toDF()
display(df2)

Working with DataFrames

Union 2 DataFrames.

val unionDF = df1.unionAll(df2)
display(unionDF)

Write the Unioned DataFrame to a Parquet file.

// Remove the file if it exists
dbutils.fs.rm("/tmp/databricks-df-example.parquet", true)
unionDF.write.parquet("/tmp/databricks-df-example.parquet")

Read a DataFrame from the Parquet file.

import org.apache.spark.sql.Row

val parquetDF = sqlContext.read.parquet("/tmp/databricks-df-example.parquet")

// Explode the nested employees array so that each employee becomes its own row
val explodeDF = parquetDF.explode($"employees") {
  case Row(employees: Seq[Row]) => employees.map { employee =>
    val firstName = employee(0).asInstanceOf[String]
    val lastName = employee(1).asInstanceOf[String]
    val email = employee(2).asInstanceOf[String]
    val salary = employee(3).asInstanceOf[Int]
    Employee(firstName, lastName, email, salary)
  }
}.cache()
display(explodeDF)

Use ``filter()`` to return only the rows that match the given predicate.

val filterDF = explodeDF
  .filter($"firstName" === "xiangrui" || $"firstName" === "michael")
  .sort($"lastName".asc)
display(filterDF)

The ``where()`` clause is equivalent to ``filter()``.

val whereDF = explodeDF.where(($"firstName" === "xiangrui") || ($"firstName" === "michael")).sort($"lastName".asc)
display(whereDF)

Replace ``null`` values with ``--`` using DataFrame ``na`` functions.

val naFunctions = explodeDF.na
val nonNullDF = naFunctions.fill("--")
display(nonNullDF)

Retrieve only the rows with a missing firstName or lastName (these were replaced with ``--`` above).

val filterNonNullDF = nonNullDF.filter($"firstName" === "--" || $"lastName" === "--").sort($"email".asc)
display(filterNonNullDF)

Example aggregations using ``agg()`` and ``countDistinct()``.

import org.apache.spark.sql.functions._
// Find the distinct (firstName, lastName) combinations
val countDistinctDF = nonNullDF.select($"firstName", $"lastName")
  .groupBy($"firstName", $"lastName")
  .agg(countDistinct($"firstName") as "distinct_first_names")
display(countDistinctDF)

Compare the DataFrame and SQL Query Physical Plans (Hint: They should be the same.)

countDistinctDF.explain()
// register the DataFrame as a temp table so that we can query it using SQL
nonNullDF.registerTempTable("databricks_df_example")

// Perform the same query in SQL and call explain to compare the physical plans
sqlContext.sql("""
SELECT firstName, lastName, count(distinct firstName) as distinct_first_names
FROM databricks_df_example
GROUP BY firstName, lastName
""").explain
// Sum up all the salaries
val salarySumDF = nonNullDF.agg("salary" -> "sum")
display(salarySumDF)

Print the summary statistics for the salaries.

nonNullDF.describe("salary").show()
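
Roughly the same statistics can be computed explicitly with agg() and the built-in aggregate functions. A small sketch, not part of the original example (stddev requires Spark 1.6 or later):

// Compute count, mean, stddev, min, and max of salary explicitly
display(nonNullDF.agg(
  count($"salary").as("count"),
  avg($"salary").as("mean"),
  stddev($"salary").as("stddev"),
  min($"salary").as("min"),
  max($"salary").as("max")))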

Flattening

If your data has several levels of nesting, here is a helper function to flatten your DataFrame to make it easier to work with.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val veryNestedDF = Seq(("1", (2, (3, 4)))).toDF()

implicit class DataFrameFlattener(df: DataFrame) {
  def flattenSchema: DataFrame = {
    df.select(flatten(Nil, df.schema): _*)
  }

  // Recursively walk the schema and emit one column per leaf field,
  // aliased with the dot-separated path to that field.
  protected def flatten(path: Seq[String], schema: DataType): Seq[Column] = schema match {
    case s: StructType => s.fields.flatMap(f => flatten(path :+ f.name, f.dataType))
    case _ => col(path.map(n => s"`$n`").mkString(".")).as(path.mkString(".")) :: Nil
  }
}
display(veryNestedDF)
display(veryNestedDF.flattenSchema)

Clean up: remove the Parquet file.

dbutils.fs.rm("/tmp/databricks-df-example.parquet", true)

DataFrame FAQs

This FAQ covers common use cases and example usage of the available APIs.

Q: How can I get better performance with DataFrame UDFs? A: If the equivalent functionality exists in the built-in functions, prefer those; they generally perform better than UDFs. Example usage below.

The examples below use built-in functions with the withColumn() API to add new columns. You could also use withColumnRenamed() to replace an existing column after the transformation (see the short sketch after the transformations below). Note: import the required libraries in the first cell.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

// Build an example DataFrame dataset to work with.
dbutils.fs.rm("/tmp/dataframe_sample.csv", true)
dbutils.fs.put("/tmp/dataframe_sample.csv", """
id|end_date|start_date|location
1|2015-10-14 00:00:00|2015-09-14 00:00:00|CA-SF
2|2015-10-15 01:00:20|2015-08-14 00:00:00|CA-SD
3|2015-10-16 02:30:00|2015-01-14 00:00:00|NY-NY
4|2015-10-17 03:00:20|2015-02-14 00:00:00|NY-NY
5|2015-10-18 04:30:00|2014-04-14 00:00:00|CA-LA
""", true)

val conf = new Configuration
conf.set("textinputformat.record.delimiter", "\n")
val rdd = sc.newAPIHadoopFile("/tmp/dataframe_sample.csv", classOf[TextInputFormat], classOf[LongWritable], classOf[Text], conf).map(_._2.toString).filter(_.nonEmpty)

val header = rdd.first()
// Remove the header line
val rdd_noheader = rdd.filter(_ != header)
// Convert the RDD[String] to an RDD[Row]. Split each line on the delimiter and use Row.fromSeq()
val row_rdd = rdd_noheader.map(x => x.split('|')).map(x => Row.fromSeq(x))

val df_schema =
  StructType(
    header.split('|').map(fieldName => StructField(fieldName, StringType, true)))

var df = sqlContext.createDataFrame(row_rdd, df_schema)
df.printSchema
// Instead of registering a UDF, call the built-in functions to perform operations on the columns.
// Built-in functions can be optimized by Catalyst, so they generally perform better than UDFs.

// Convert to a Date type
val timestamp2datetype: (Column) => Column = (x) => { to_date(x) }
df = df.withColumn("date", timestamp2datetype(col("end_date")))

// Parse out the date only
val timestamp2date: (Column) => Column = (x) => { regexp_replace(x," (\\d+)[:](\\d+)[:](\\d+).*$", "") }
df = df.withColumn("date_only", timestamp2date(col("end_date")))

// Split a string and index a field
val parse_city: (Column) => Column = (x) => { split(x, "-")(1) }
df = df.withColumn("city", parse_city(col("location")))

// Perform a date diff function
val dateDiff: (Column, Column) => Column = (x, y) => { datediff(to_date(y), to_date(x)) }
df = df.withColumn("date_diff", dateDiff(col("start_date"), col("end_date")))
df.registerTempTable("sample_df")
display(sqlContext.sql("select * from sample_df"))
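
As noted earlier, withColumnRenamed() can rename a derived column after the transformation. A small sketch (the new name end_date_only is arbitrary and purely illustrative):

// Rename the derived date_only column using withColumnRenamed
df = df.withColumnRenamed("date_only", "end_date_only")
display(df)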

Q: I want to convert the DataFrame back to JSON strings to send back to Kafka. A: There is a toJSON() method that returns an RDD of JSON strings, using the column names and schema to produce the JSON records.

val rdd_json = df.toJSON
rdd_json.take(2).foreach(println)

Q: My UDF takes a parameter, including the column to operate on. How do I pass this parameter? A: Use the lit() function, which creates a Column of a literal (constant) value that can be passed to the UDF alongside other columns.

val add_n = udf((x: Integer, y: Integer) => x + y)

// Use the UDF to compute a new column, casting the id column to an integer first.
df = df.withColumn("id_offset", add_n(lit(1000), col("id").cast("int")))
display(df)
val last_n_days = udf((x: Integer, y: Integer) => x < y)

val df_filtered = df.filter(last_n_days(col("date_diff"), lit(90)))
display(df_filtered)

Q: I have a table in the Hive metastore and I'd like to access the table as a DataFrame. What's the best way to do this? A: There are multiple ways to define a DataFrame from a registered table; the syntax is shown below. Call table(tableName), or select and filter specific columns using a SQL query.

// Both return DataFrame types
val df_1 = table("sample_df")
val df_2 = sqlContext.sql("select * from sample_df")

Q: I'd like to clear all the cached tables on the current cluster. A: There are APIs available to do this at the global level or per table.

sqlContext.clearCache()                // clear every cached table
sqlContext.cacheTable("sample_df")     // cache a single table
sqlContext.uncacheTable("sample_df")   // uncache a single table

Q: I'd like to compute aggregates on columns. What's the best way to do this? A: Use the agg() method, which takes a list of column expressions describing the aggregations you'd like to compute. You can use the built-in functions mentioned above as part of the expressions for each column.

// Compute the min id, count of ids, and average date_diff, grouped by the location column. Display the results.
var agg_df = df.groupBy("location").agg(min("id"), count("id"), avg("date_diff"))
display(agg_df)

Q: I'd like to write out the DataFrames to Parquet, but would like to partition on a particular column. A: You can use the following APIs to accomplish this. Make sure the code doesn't create a large number of partition columns with these datasets, otherwise the metadata overhead can cause significant slowdowns. If there is a SQL table backed by this directory, you will need to run REFRESH TABLE tableName to update the metadata before querying it again (see the example after the listing below).

df = df.withColumn("end_month", month(col("end_date")))
df = df.withColumn("end_year", year(col("end_date")))
dbutils.fs.rm("/tmp/sample_table", true)
df.write.partitionBy("end_year", "end_month").parquet("/tmp/sample_table")
display(dbutils.fs.ls("/tmp/sample_table"))
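
A small sketch of the refresh step mentioned above, assuming a hypothetical table named sample_table_partitioned is registered over /tmp/sample_table:

// Hypothetical table name; only needed when a metastore table is backed by the partitioned directory
sqlContext.sql("REFRESH TABLE sample_table_partitioned")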

Q: How do I properly handle cases where I want to filter out NULL data? A: Use filter() with syntax similar to a SQL WHERE clause.

val null_item_schema = StructType(Array(StructField("col1", StringType, true),
                               StructField("col2", IntegerType, true)))

val null_dataset = sc.parallelize(Array(("test", 1 ), (null, 2))).map(x => Row.fromTuple(x))
val null_df = sqlContext.createDataFrame(null_dataset, null_item_schema)
display(null_df.filter("col1 IS NOT NULL"))
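
Equivalently, the same predicate can be written with the Column API (a small sketch using the isNotNull operator):

display(null_df.filter($"col1".isNotNull))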

Q: How do I infer the schema using the spark-csv or spark-avro libraries? A: As documented in the spark-csv GitHub project, there is an inferSchema option flag. Providing a header also lets you name the columns appropriately.

val adult_df = sqlContext.read.
    format("com.databricks.spark.csv").
    option("header", "false").
    option("inferSchema", "true").load("dbfs:/databricks-datasets/adult/adult.data")
adult_df.printSchema()
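
As a sketch, a file that includes a header row and a custom delimiter (the path here is hypothetical) can be read so that spark-csv both names and types the columns:

val inferred_df = sqlContext.read.
    format("com.databricks.spark.csv").
    option("header", "true").
    option("inferSchema", "true").
    option("delimiter", "|").
    load("dbfs:/tmp/your_pipe_delimited_file.csv") // hypothetical path
inferred_df.printSchema()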

Q: I have a dataset of delimited strings and I want to convert the fields to their proper data types. How would I accomplish this? A: Use the RDD APIs to filter out the malformed rows and map the remaining values to the appropriate types.
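
A minimal sketch of this approach, reusing the pipe-delimited sample file loaded above; the target types chosen here are assumptions for illustration:

// Drop rows that don't have exactly 4 fields, then convert each field to a typed value
val parsed_rdd = rdd_noheader
  .map(_.split('|'))
  .filter(_.length == 4)                      // filter out malformed rows
  .map(fields => Row(
    fields(0).trim.toInt,                     // id -> Int
    java.sql.Timestamp.valueOf(fields(1)),    // end_date -> Timestamp
    java.sql.Timestamp.valueOf(fields(2)),    // start_date -> Timestamp
    fields(3)))                               // location stays a String

val typed_schema = StructType(Array(
  StructField("id", IntegerType, true),
  StructField("end_date", TimestampType, true),
  StructField("start_date", TimestampType, true),
  StructField("location", StringType, true)))

val typed_df = sqlContext.createDataFrame(parsed_rdd, typed_schema)
typed_df.printSchema()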