Introduction to DataFrames - Python

Creating DataFrames with Python

# Import the pyspark Row class from the sql module
from pyspark.sql import Row

# Create Example Data - Departments and Employees

# Create the Departments
department1 = Row(id='123456', name='Computer Science')
department2 = Row(id='789012', name='Mechanical Engineering')
department3 = Row(id='345678', name='Theater and Drama')
department4 = Row(id='901234', name='Indoor Recreation')

# Create the Employees
Employee = Row("firstName", "lastName", "email", "salary")
employee1 = Employee('michael', 'armbrust', 'no-reply@berkeley.edu', 100000)
employee2 = Employee('xiangrui', 'meng', 'no-reply@stanford.edu', 120000)
employee3 = Employee('matei', None, 'no-reply@waterloo.edu', 140000)
employee4 = Employee(None, 'wendell', 'no-reply@berkeley.edu', 160000)

# Create the DepartmentWithEmployees instances from Departments and Employees
departmentWithEmployees1 = Row(department=department1, employees=[employee1, employee2])
departmentWithEmployees2 = Row(department=department2, employees=[employee3, employee4])
departmentWithEmployees3 = Row(department=department3, employees=[employee1, employee4])
departmentWithEmployees4 = Row(department=department4, employees=[employee2, employee3])

print(department1)
print(employee2)
print(departmentWithEmployees1.employees[0].email)

Create the first DataFrame from a list of rows.

departmentsWithEmployeesSeq1 = [departmentWithEmployees1, departmentWithEmployees2]
df1 = sqlContext.createDataFrame(departmentsWithEmployeesSeq1)

display(df1)

Create a second DataFrame from a list of rows.

departmentsWithEmployeesSeq2 = [departmentWithEmployees3, departmentWithEmployees4]
df2 = sqlContext.createDataFrame(departmentsWithEmployeesSeq2)

display(df2)

Working with DataFrames

Union two DataFrames.

unionDF = df1.unionAll(df2)
display(unionDF)

Write the unioned DataFrame to a Parquet file.

# Remove the file if it exists
dbutils.fs.rm("/tmp/databricks-df-example.parquet", True)
unionDF.write.parquet("/tmp/databricks-df-example.parquet")

Read a DataFrame from the Parquet file.

parquetDF = sqlContext.read.parquet("/tmp/databricks-df-example.parquet")
display(parquetDF)

Explode the employees column. The doctest below, taken from the explode() function's documentation, shows how explode() expands array and map columns into rows:

>>> from pyspark.sql import Row
>>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
>>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
[Row(anInt=1), Row(anInt=2), Row(anInt=3)]

>>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+---+-----+
|key|value|
+---+-----+
|  a|    b|
+---+-----+
"""
from pyspark.sql.functions import explode

df = parquetDF.select(explode("employees").alias("e"))
explodeDF = df.selectExpr("e.firstName", "e.lastName", "e.email", "e.salary")

display(explodeDF)

Use filter() to return only the rows that match the given predicate.

filterDF = explodeDF.filter(explodeDF.firstName == "xiangrui").sort(explodeDF.lastName)
display(filterDF)
from pyspark.sql.functions import col, asc

# Use `|` instead of Python's `or` when combining Column conditions
filterDF = explodeDF.filter((col("firstName") == "xiangrui") | (col("firstName") == "michael")).sort(asc("lastName"))
display(filterDF)

The where() clause is equivalent to filter().

whereDF = explodeDF.where((col("firstName") == "xiangrui") | (col("firstName") == "michael")).sort(asc("lastName"))
display(whereDF)

Replace null values with -- using the DataFrame na functions.

nonNullDF = explodeDF.fillna("--")
display(nonNullDF)
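
Note that fillna() is an alias for na.fill(), so the same replacement can also be written through the DataFrame's na accessor:

# Equivalent call through the DataFrameNaFunctions accessor
nonNullDF2 = explodeDF.na.fill("--")
display(nonNullDF2)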

Retrieve only rows with missing firstName or lastName.

filterNonNullDF = explodeDF.filter(col("firstName").isNull() | col("lastName").isNull()).sort("email")
display(filterNonNullDF)

Example aggregations using agg() and countDistinct().

from pyspark.sql.functions import countDistinct

countDistinctDF = explodeDF.select("firstName", "lastName")\
  .groupBy("firstName", "lastName")\
  .agg(countDistinct("firstName"))

display(countDistinctDF)

Compare the DataFrame and SQL Query Physical Plans (Hint: They should be the same.)

countDistinctDF.explain()
# register the DataFrame as a temp table so that we can query it using SQL
explodeDF.registerTempTable("databricks_df_example")

# Perform the same query as the DataFrame above and compare the physical plans via explain()
countDistinctDF_sql = sqlContext.sql("""
  SELECT firstName, lastName, count(DISTINCT firstName) AS distinct_first_names
  FROM databricks_df_example
  GROUP BY firstName, lastName
""")

countDistinctDF_sql.explain()

Sum up all the salaries.

salarySumDF = explodeDF.agg({"salary" : "sum"})
display(salarySumDF)
type(explodeDF.salary)  # salary is a Column object

Print the summary statistics for the salaries.

explodeDF.describe("salary").show()

An example using pandas and Matplotlib integration.

import pandas as pd
import matplotlib.pyplot as plt
plt.clf()

# Convert the Spark DataFrame to a pandas DataFrame and plot salary by first name
pdDF = nonNullDF.toPandas()
pdDF.plot(x='firstName', y='salary', kind='bar', rot=45)
display()

Cleanup: remove the Parquet file.

dbutils.fs.rm("/tmp/databricks-df-example.parquet", True)

DataFrame FAQs

This FAQ addresses common use cases and example usage using the available DataFrame APIs. For more detailed API descriptions, see the PySpark documentation.

Q: How can I get better performance with DataFrame UDFs?
A: If the functionality exists in the available built-in functions, using these will perform better, as shown in the example below (see the pyspark.sql.functions documentation for details). We use the built-in functions and the withColumn() API to add new columns. We could have also used withColumnRenamed() to replace an existing column after the transformation.

from pyspark.sql import functions as F
from pyspark.sql.types import *

# Build an example DataFrame dataset to work with.
dbutils.fs.rm("/tmp/dataframe_sample.csv", True)
dbutils.fs.put("/tmp/dataframe_sample.csv", """id|end_date|start_date|location
1|2015-10-14 00:00:00|2015-09-14 00:00:00|CA-SF
2|2015-10-15 01:00:20|2015-08-14 00:00:00|CA-SD
3|2015-10-16 02:30:00|2015-01-14 00:00:00|NY-NY
4|2015-10-17 03:00:20|2015-02-14 00:00:00|NY-NY
5|2015-10-18 04:30:00|2014-04-14 00:00:00|CA-SD
""", True)

# The built-in csv data source is available in Spark 2.0+; otherwise fall back to the spark-csv package
formatPackage = "csv" if int(sc.version.split('.')[0]) >= 2 else "com.databricks.spark.csv"
df = sqlContext.read.format(formatPackage).options(header='true', delimiter='|').load("/tmp/dataframe_sample.csv")
df.printSchema()
# Instead of registering a UDF, call the builtin functions to perform operations on the columns.
# This will provide a performance improvement as the builtins compile and run in the platform's JVM.

# Convert to a Date type
df = df.withColumn('date', F.to_date(df.end_date))

# Parse out the date only
df = df.withColumn('date_only', F.regexp_replace(df.end_date, r' (\d+)[:](\d+)[:](\d+).*$', ''))

# Split a string and index a field
df = df.withColumn('city', F.split(df.location, '-')[1])

# Perform a date diff function
df = df.withColumn('date_diff', F.datediff(F.to_date(df.end_date), F.to_date(df.start_date)))
df.registerTempTable("sample_df")
display(sqlContext.sql("select * from sample_df"))
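
As mentioned above, withColumnRenamed() could also be used to replace an existing column after a transformation. A short illustration, renaming the derived date column (the new name end_date_parsed is only illustrative):

# Rename the `date` column produced by the transformation above
renamed_df = df.withColumnRenamed('date', 'end_date_parsed')
renamed_df.printSchema()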
Q: I want to convert the DataFrame back to JSON strings to send back to Kafka.
A: There is a toJSON() method that returns an RDD of JSON strings, using the column names and schema to produce the JSON records.
rdd_json = df.toJSON()
rdd_json.take(2)
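
To confirm that each element is a plain JSON string, it can be parsed back with Python's json module (a quick illustrative check, not part of the original example):

import json

# Parse the first JSON record back into a Python dict to inspect it
print(json.loads(rdd_json.first()))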
Q: My UDF takes a parameter including the column to operate on. How do I pass this parameter?
A: There is a function available called lit() that creates a constant column.
from pyspark.sql import functions as F

add_n = F.udf(lambda x, y: x + y, IntegerType())

# Use the UDF with withColumn() to add a new column; cast the id column to an integer first
df = df.withColumn('id_offset', add_n(F.lit(1000), df.id.cast(IntegerType())))
display(df)
# Any constants used by the UDF are automatically passed to the workers
N = 90
last_n_days = F.udf(lambda x: x < N, BooleanType())

df_filtered = df.filter(last_n_days(df.date_diff))
display(df_filtered)
Q: I have a table in the Hive metastore and I’d like to access the table as a DataFrame. What’s the best way to define this?
A: There are multiple ways to define a DataFrame from a registered table. Call sqlContext.table(tableName), or select and filter specific columns using a SQL query:
# Both return DataFrame types
df_1 = sqlContext.table("sample_df")
df_2 = sqlContext.sql("select * from sample_df")
Q: I’d like to clear all the cached tables on the current cluster.
A: There’s an API available to do this at a global level or per table.
# Clear all cached tables on the current cluster
sqlContext.clearCache()

# Cache and then uncache a specific table
sqlContext.cacheTable("sample_df")
sqlContext.uncacheTable("sample_df")
Q: I’d like to compute aggregates on columns. What’s the best way to do this?
A: There’s an API named agg(*exprs) that takes a list of column names and expressions for the type of aggregation you’d like to compute. You can leverage the built-in functions mentioned above as part of the expressions for each column.
# Provide the min, count, and avg, grouped by the location column. Display the results
agg_df = df.groupBy("location").agg(F.min("id"), F.count("id"), F.avg("date_diff"))
display(agg_df)

Q: I’d like to write out the DataFrames to Parquet, but would like to partition on a particular column.
A: You can use the following APIs to accomplish this. Ensure the code does not create a large number of partition columns with the datasets, otherwise the metadata overhead can cause significant slowdowns. If there is a SQL table backed by this directory, you will need to call REFRESH TABLE tableName to update the metadata prior to the query.

df = df.withColumn('end_month', F.month('end_date'))
df = df.withColumn('end_year', F.year('end_date'))
df.write.partitionBy("end_year", "end_month").parquet("/tmp/sample_table")
display(dbutils.fs.ls("/tmp/sample_table"))
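
If the partitioned directory backs a SQL table, the metadata refresh mentioned above can be issued through SQL. A sketch, assuming a table named sample_table has been created over /tmp/sample_table (the table name is hypothetical):

# Hypothetical: assumes a table named `sample_table` exists over /tmp/sample_table
sqlContext.sql("REFRESH TABLE sample_table")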
Q: How do I properly handle cases where I want to filter out NULL data?
A: You can use filter() and provide similar syntax as you would with a SQL query.
null_item_schema = StructType([StructField("col1", StringType(), True),
                               StructField("col2", IntegerType(), True)])
null_df = sqlContext.createDataFrame([("test", 1), (None, 2)], null_item_schema)
display(null_df.filter("col1 IS NOT NULL"))
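
Equivalently, the same intent can be expressed with the Column API or the DataFrame na functions (a brief alternative sketch):

from pyspark.sql.functions import col

# Same filter expressed with the Column API
display(null_df.filter(col("col1").isNotNull()))

# Or drop any row with a null in col1 using the na functions
display(null_df.na.drop(subset=["col1"]))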
Q: How do I infer the schema using the spark-csv or spark-avro libraries?
A: As documented in the spark-csv and spark-avro GitHub projects, there is an inferSchema option flag. Providing a header ensures appropriate column naming.
adult_df = sqlContext.read.\
    format("com.databricks.spark.csv").\
    option("header", "false").\
    option("inferSchema", "true").load("dbfs:/databricks-datasets/adult/adult.data")
adult_df.printSchema()
Q: I have a delimited string dataset that I want to convert to the appropriate data types. How would I accomplish this?
A: Use the RDD APIs to filter out the malformed rows and map the values to the appropriate types. We define a function that filters the items using regular expressions, as sketched below.
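
A minimal sketch of this approach, assuming a small pipe-delimited dataset of id and amount fields (the data, column names, and regular expression are illustrative, not from the original example):

import re

# Hypothetical pipe-delimited data; the values and layout are illustrative
raw_rdd = sc.parallelize(["1|100.5", "2|not_a_number", "3|42.0"])

# Keep only rows whose second field looks like a number
def is_well_formed(line):
    return re.match(r'^\d+\|\d+(\.\d+)?$', line) is not None

typed_rdd = (raw_rdd
             .filter(is_well_formed)
             .map(lambda line: line.split('|'))
             .map(lambda fields: (int(fields[0]), float(fields[1]))))

typed_df = sqlContext.createDataFrame(typed_rdd, ["id", "amount"])
typed_df.printSchema()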