Introduction to DataFrames - Scala

This article demonstrates a number of common Spark DataFrame functions using Scala.

Create DataFrames

// Create the case classes for our domain
case class Department(id: String, name: String)
case class Employee(firstName: String, lastName: String, email: String, salary: Int)
case class DepartmentWithEmployees(department: Department, employees: Seq[Employee])

// Create the Departments
val department1 = new Department("123456", "Computer Science")
val department2 = new Department("789012", "Mechanical Engineering")
val department3 = new Department("345678", "Theater and Drama")
val department4 = new Department("901234", "Indoor Recreation")

// Create the Employees
val employee1 = new Employee("michael", "armbrust", "no-reply@berkeley.edu", 100000)
val employee2 = new Employee("xiangrui", "meng", "no-reply@stanford.edu", 120000)
val employee3 = new Employee("matei", null, "no-reply@waterloo.edu", 140000)
val employee4 = new Employee(null, "wendell", "no-reply@princeton.edu", 160000)
val employee5 = new Employee("michael", "jackson", "no-reply@neverla.nd", 80000)

// Create the DepartmentWithEmployees instances from Departments and Employees
val departmentWithEmployees1 = new DepartmentWithEmployees(department1, Seq(employee1, employee2))
val departmentWithEmployees2 = new DepartmentWithEmployees(department2, Seq(employee3, employee4))
val departmentWithEmployees3 = new DepartmentWithEmployees(department3, Seq(employee5, employee4))
val departmentWithEmployees4 = new DepartmentWithEmployees(department4, Seq(employee2, employee3))

Create DataFrames from a list of the case classes

val departmentsWithEmployeesSeq1 = Seq(departmentWithEmployees1, departmentWithEmployees2)
val df1 = departmentsWithEmployeesSeq1.toDF()

val departmentsWithEmployeesSeq2 = Seq(departmentWithEmployees3, departmentWithEmployees4)
val df2 = departmentsWithEmployeesSeq2.toDF()

Work with DataFrames

Union two DataFrames

val unionDF = df1.union(df2)

Write the unioned DataFrame to a Parquet file

// Remove the file if it exists
dbutils.fs.rm("/tmp/databricks-df-example.parquet", true)
unionDF.write.parquet("/tmp/databricks-df-example.parquet")

Read a DataFrame from the Parquet file

val parquetDF ="/tmp/databricks-df-example.parquet")

Explode the employees column

import org.apache.spark.sql.functions._

val explodeDF =$"employees"))

Flatten the fields of the employee class into columns

val flattenDF =$"col.*")

+---------+--------+--------------------+------+
|firstName|lastName|               email|salary|
+---------+--------+--------------------+------+
|    matei|    null|no-reply@waterloo...|140000|
|     null| wendell|no-reply@princeto...|160000|
|  michael|armbrust|no-reply@berkeley...|100000|
| xiangrui|    meng|no-reply@stanford...|120000|
|  michael| jackson| no-reply@neverla.nd| 80000|
|     null| wendell|no-reply@princeto...|160000|
| xiangrui|    meng|no-reply@stanford...|120000|
|    matei|    null|no-reply@waterloo...|140000|
+---------+--------+--------------------+------+

Use filter() to return the rows that match a predicate

val filterDF = flattenDF
  .filter($"firstName" === "xiangrui" || $"firstName" === "michael")

The where() clause is equivalent to filter()

val whereDF = flattenDF
  .where($"firstName" === "xiangrui" || $"firstName" === "michael")

Replace null values with -- using the DataFrame na functions

val nonNullDF = flattenDF.na.fill("--")

Retrieve rows with missing firstName or lastName

val filterNonNullDF = nonNullDF.filter($"firstName" === "--" || $"lastName" === "--").sort($"email".asc)

Example aggregations using agg() and countDistinct()

// Find the distinct last names for each first name
val countDistinctDF = nonNullDF
  .select($"firstName", $"lastName")
  .groupBy($"firstName")
  .agg(countDistinct($"lastName") as "distinct_last_names")

Compare the DataFrame and SQL query physical plans


They should be the same.

countDistinctDF.explain()

// Register the DataFrame as a temp view so that we can query it using SQL
nonNullDF.createOrReplaceTempView("databricks_df_example")

spark.sql("""
  SELECT firstName, count(distinct lastName) as distinct_last_names
  FROM databricks_df_example
  GROUP BY firstName
""").explain

Sum up all the salaries

val salarySumDF = nonNullDF.agg("salary" -> "sum")

Cleanup: remove the Parquet file

dbutils.fs.rm("/tmp/databricks-df-example.parquet", true)

Frequently asked questions (FAQ)

This FAQ addresses common use cases and example usage using the available APIs. For more detailed API descriptions, see the DataFrameReader and DataFrameWriter documentation.

How can I get better performance with DataFrame UDFs?

If the functionality you need exists in the available built-in functions, using them will perform better than a UDF.

We use the built-in functions and the withColumn() API to add new columns. We could have also used withColumnRenamed() to replace an existing column after the transformation.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

// Build an example DataFrame dataset to work with.
dbutils.fs.rm("/tmp/dataframe_sample.csv", true)
dbutils.fs.put("/tmp/dataframe_sample.csv", """id|end_date|start_date|location
1|2015-10-14 00:00:00|2015-09-14 00:00:00|CA-SF
2|2015-10-15 01:00:20|2015-08-14 00:00:00|CA-SD
3|2015-10-16 02:30:00|2015-01-14 00:00:00|NY-NY
4|2015-10-17 03:00:20|2015-02-14 00:00:00|NY-NY
5|2015-10-18 04:30:00|2014-04-14 00:00:00|CA-LA
""", true)

val conf = new Configuration
conf.set("textinputformat.record.delimiter", "\n")
val rdd = sc.newAPIHadoopFile("/tmp/dataframe_sample.csv", classOf[TextInputFormat], classOf[LongWritable], classOf[Text], conf).map(_._2.toString).filter(_.nonEmpty)

val header = rdd.first()
// Filter out the header line
val rdd_noheader = rdd.filter(x => !x.contains("id"))
// Convert the RDD[String] to an RDD[Row]: split on the delimiter and use Row.fromSeq()
val row_rdd = => x.split('|')).map(x => Row.fromSeq(x))

val df_schema = StructType(
    header.split('|').map(fieldName => StructField(fieldName, StringType, true)))

var df = spark.createDataFrame(row_rdd, df_schema)
// Instead of registering a UDF, call the builtin functions to perform operations on the columns.
// This will provide a performance improvement as the builtins compile and run in the platform's JVM.

// Convert to a Date type
val timestamp2datetype: (Column) => Column = (x) => { to_date(x) }
df = df.withColumn("date", timestamp2datetype(col("end_date")))

// Parse out the date only
val timestamp2date: (Column) => Column = (x) => { regexp_replace(x," (\\d+)[:](\\d+)[:](\\d+).*$", "") }
df = df.withColumn("date_only", timestamp2date(col("end_date")))

// Split a string and index a field
val parse_city: (Column) => Column = (x) => { split(x, "-")(1) }
df = df.withColumn("city", parse_city(col("location")))

// Perform a date diff function
val dateDiff: (Column, Column) => Column = (x, y) => { datediff(to_date(y), to_date(x)) }
df = df.withColumn("date_diff", dateDiff(col("start_date"), col("end_date")))
df.createOrReplaceTempView("sample_df")
display(sql("select * from sample_df"))

I want to convert the DataFrame back to JSON strings to send back to Kafka.

There is a toJSON() function that returns an RDD of JSON strings using the column names and schema to produce the JSON records.

val rdd_json = df.toJSON
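
To actually send those records to Kafka, one option is a sketch like the following, which uses the built-in to_json and struct functions with Spark's Kafka batch sink (requires the spark-sql-kafka connector; the broker address and topic name here are placeholders, not values from this article):

```scala
import org.apache.spark.sql.functions.{col, struct, to_json}

// Serialize every row to a JSON string in a "value" column, then write the
// batch to a Kafka topic. Broker and topic are illustrative placeholders.
df.select(to_json(struct(df.columns.map(col): _*)).alias("value"))
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "broker1:9092")
  .option("topic", "df-output")
  .save()
```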

My UDF takes a parameter including the column to operate on. How do I pass this parameter?

There is a function available called lit() that creates a static column.

val add_n = udf((x: Integer, y: Integer) => x + y)

// We register a UDF that adds a column to the DataFrame, and we cast the id column to an Integer type.
df = df.withColumn("id_offset", add_n(lit(1000), col("id").cast("int")))
val last_n_days = udf((x: Integer, y: Integer) => {
  if (x < y) true else false
})

val df_filtered = df.filter(last_n_days(col("date_diff"), lit(90)))

I have a table in the Hive metastore and I’d like to access that table as a DataFrame. What’s the best way to define this?

There are multiple ways to define a DataFrame from a registered table. Call table(tableName) or select and filter specific columns using an SQL query:

// Both return DataFrame types
val df_1 = spark.table("sample_df")
val df_2 = spark.sql("select * from sample_df")

I’d like to clear all the cached tables on the current cluster.

There’s an API available to do this at the global or per table level.

spark.catalog.clearCache()
spark.catalog.uncacheTable("sample_df")
I’d like to compute aggregates on columns. What’s the best way to do this?

There’s an API named agg(*exprs) that takes a list of column names and expressions for the type of aggregation you’d like to compute. You can leverage the built-in functions mentioned above as part of the expressions for each column.

// Provide the min, count, and avg and groupBy the location column. Display the results
var agg_df = df.groupBy("location").agg(min("id"), count("id"), avg("date_diff"))

I’d like to write out the DataFrames to Parquet, but would like to partition on a particular column.

You can use the following APIs to accomplish this. Ensure that the code does not create a large number of partition columns with these datasets; otherwise the metadata overhead can cause significant slowdowns. If there is a SQL table backed by this directory, you will need to call refresh table <table-name> to update the metadata prior to the query.

df = df.withColumn("end_month", month(col("end_date")))
df = df.withColumn("end_year", year(col("end_date")))
dbutils.fs.rm("/tmp/sample_table", true)
df.write.partitionBy("end_year", "end_month").parquet("/tmp/sample_table")

How do I properly handle cases where I want to filter out NULL data?

You can use filter() and provide similar syntax as you would with a SQL query.

val null_item_schema = StructType(Array(StructField("col1", StringType, true),
                               StructField("col2", IntegerType, true)))

val null_dataset = sc.parallelize(Array(("test", 1 ), (null, 2))).map(x => Row.fromTuple(x))
val null_df = spark.createDataFrame(null_dataset, null_item_schema)
display(null_df.filter("col1 IS NOT NULL"))

How do I infer the schema using the csv or spark-avro libraries?

There is an inferSchema option flag. Providing a header allows you to name the columns appropriately.

val adult_df ="csv").
    option("header", "false").
    option("inferSchema", "true").load("dbfs:/databricks-datasets/adult/")

You have a dataset of delimited strings that you want to convert to the appropriate data types. How would you accomplish this?

Use the RDD APIs to filter out the malformed rows and map the values to the appropriate types.
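
As a minimal sketch, assuming the id|end_date|start_date|location layout of the sample file above (the Record case class and parseLine helper are illustrative names, not part of any Spark API):

```scala
import java.sql.Timestamp
import java.text.SimpleDateFormat

// Illustrative typed record for the pipe-delimited sample used earlier.
case class Record(id: Int, endDate: Timestamp, startDate: Timestamp, location: String)

// Returns None for malformed rows so they can be dropped with flatMap.
def parseLine(line: String): Option[Record] = {
  val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  val fields = line.split('|')
  if (fields.length != 4) None          // wrong number of columns
  else try {
    Some(Record(fields(0).trim.toInt,
                new Timestamp(fmt.parse(fields(1)).getTime),
                new Timestamp(fmt.parse(fields(2)).getTime),
                fields(3)))
  } catch { case _: Exception => None } // unparsable number or date
}

// On the cluster, apply it to the header-stripped RDD and convert to a
// typed DataFrame:
// val typed_df = rdd_noheader.flatMap(parseLine).toDF()
```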