Simplify Chained Transformations

Sometimes you need to apply several transformations to a DataFrame in sequence. The examples in this topic use the following DataFrame and transformation functions:

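import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import spark.implicits._ // needed for toDF; assumes a SparkSession named spark, as in a notebook

val testDf = (1 to 10).toDF("col")

// Keep only the rows where col is greater than x(y)
def func0(x: Int => Int, y: Int)(in: DataFrame): DataFrame = {
  in.filter(col("col") > x(y))
}
// Keep col and add col1 = col + x
def func1(x: Int)(in: DataFrame): DataFrame = {
  in.selectExpr("col", s"col + $x as col1")
}
// Add col2 = col1 + add
def func2(add: Int)(in: DataFrame): DataFrame = {
  in.withColumn("col2", expr(s"col1 + $add"))
}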
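Each function takes its parameters in a first parameter list and the DataFrame in a second. Supplying only the first list therefore yields a function from DataFrame to DataFrame, which is what the techniques below rely on. The following sketch shows this with the Scala standard library only. So that it runs without Spark, it replaces DataFrame with a hypothetical Frame type (a list of rows, each a map from column name to value); Frame and the model of func1 are illustrative assumptions, not Spark APIs.

```scala
object PartialApplication {
  // Hypothetical stand-in for DataFrame: a list of rows, each a column-name -> value map.
  type Frame = List[Map[String, Int]]

  // Same shape as func1 above: parameters first, the "DataFrame" last.
  def func1(x: Int)(in: Frame): Frame =
    in.map(row => Map("col" -> row("col"), "col1" -> (row("col") + x)))

  // Supplying only the first parameter list yields a Frame => Frame value.
  // The declared function type triggers eta-expansion of the remaining list.
  val addOne: Frame => Frame = func1(1)

  val input: Frame = (1 to 3).toList.map(i => Map("col" -> i))
  val output: Frame = addOne(input)
}
```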

When you apply these transformations, you may end up with spaghetti code like this:

def inc(i: Int) = i + 1

val tmp0 = func0(inc, 3)(testDf)
val tmp1 = func1(1)(tmp0)
val tmp2 = func2(2)(tmp1)
val res = tmp2.withColumn("col3", expr("col2 + 3"))

This topic describes several methods to simplify chained transformations.

DataFrame transform API

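To write the chain in a functional style, use the DataFrame transform API. transform takes a function of type DataFrame => DataFrame and applies it to the DataFrame, so you can pass each function with only its first parameter list supplied, for example: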

val res = testDf.transform(func0(inc, 3))
                .transform(func1(1))
                .transform(func2(2))
                .withColumn("col3", expr("col2 + 3"))
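Because transform just applies the function it receives to the DataFrame, the chain above computes the same thing as nesting the calls inside out; it only reads top to bottom. The Spark-free sketch below illustrates that equivalence. It assumes Scala 2.13 or later, uses the standard library's scala.util.chaining pipe method as a stand-in for transform, and models DataFrame with a hypothetical Frame type; none of these names are Spark APIs.

```scala
import scala.util.chaining._

object TransformStyle {
  // Hypothetical stand-in for DataFrame: a list of rows, each a column-name -> value map.
  type Frame = List[Map[String, Int]]

  def inc(i: Int): Int = i + 1

  // Models of func0, func1, and func2, with the "DataFrame" in the last parameter list.
  def func0(x: Int => Int, y: Int)(in: Frame): Frame =
    in.filter(row => row("col") > x(y))
  def func1(x: Int)(in: Frame): Frame =
    in.map(row => Map("col" -> row("col"), "col1" -> (row("col") + x)))
  def func2(add: Int)(in: Frame): Frame =
    in.map(row => row + ("col2" -> (row("col1") + add)))

  val testDf: Frame = (1 to 10).toList.map(i => Map("col" -> i))

  // Nested, inside-out application: the innermost call runs first.
  val nested: Frame = func2(2)(func1(1)(func0(inc, 3)(testDf)))

  // pipe applies a function to its receiver, like transform, so the steps read in order.
  val piped: Frame = testDf
    .pipe(func0(inc, 3))
    .pipe(func1(1))
    .pipe(func2(2))
}
```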

Function.chain API

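To go further, use Function.chain from the Scala standard library to combine the transformations into a single DataFrame => DataFrame function. Function.chain applies the functions in list order, for example: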

val chained = Function.chain(List(func0(inc, 3)(_), func1(1)(_), func2(2)(_)))
val res = testDf.transform(chained)
                .withColumn("col3", expr("col2 + 3"))
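Function.chain feeds the output of each function into the next, so the order of the list matters: func2 reads col1, which only exists after func1 has run. The sketch below, again using a hypothetical Frame stand-in instead of a DataFrame so that it runs without Spark, checks that Function.chain gives the same result as composing the steps with andThen.

```scala
object ChainStyle {
  // Hypothetical stand-in for DataFrame: a list of rows, each a column-name -> value map.
  type Frame = List[Map[String, Int]]

  def func1(x: Int)(in: Frame): Frame =
    in.map(row => Map("col" -> row("col"), "col1" -> (row("col") + x)))
  def func2(add: Int)(in: Frame): Frame =
    in.map(row => row + ("col2" -> (row("col1") + add)))

  // func1 must come before func2, because func2 reads col1.
  val steps: List[Frame => Frame] = List(func1(1)(_), func2(2)(_))

  // Function.chain applies the functions left to right.
  val chained: Frame => Frame = Function.chain(steps)

  // The same composition written with andThen.
  val composed: Frame => Frame = steps.reduce(_ andThen _)

  val testDf: Frame = (1 to 3).toList.map(i => Map("col" -> i))
}
```

One practical difference: Function.chain on an empty list returns the identity function, whereas reduce on an empty list throws an exception, so Function.chain is the safer choice when the list of steps is built at runtime.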

implicit class

Another alternative is to define a Scala implicit class that adds the transformations as methods on DataFrame, so you no longer need to call transform:

// In a compiled Scala 2 project, define this inside an object, class, or trait
implicit class MyTransforms(df: DataFrame) {
  def func0(x: Int => Int, y: Int): DataFrame = {
    df.filter(col("col") > x(y))
  }
  def func1(x: Int): DataFrame = {
    df.selectExpr("col", s"col + $x as col1")
  }
  def func2(add: Int): DataFrame = {
    df.withColumn("col2", expr(s"col1 + $add"))
  }
}

Then you can call the functions directly:

val res = testDf.func0(inc, 3)
                .func1(1)
                .func2(2)
                .withColumn("col3", expr("col2 + 3"))
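This works because the compiler, finding no func0 method on DataFrame, converts the DataFrame to MyTransforms and calls the method there; each method returns a DataFrame, so the next call triggers the same conversion. The sketch below models this without Spark, using a hypothetical Frame stand-in and a hypothetical FrameOps implicit class. It places the implicit class inside an object, because Scala 2 does not allow implicit classes at the top level.

```scala
object ImplicitStyle {
  // Hypothetical stand-in for DataFrame: a list of rows, each a column-name -> value map.
  type Frame = List[Map[String, Int]]

  def inc(i: Int): Int = i + 1

  // Implicit classes cannot be top-level in Scala 2, so this one lives inside an object.
  implicit class FrameOps(df: Frame) {
    def func0(x: Int => Int, y: Int): Frame =
      df.filter(row => row("col") > x(y))
    def func1(x: Int): Frame =
      df.map(row => Map("col" -> row("col"), "col1" -> (row("col") + x)))
    def func2(add: Int): Frame =
      df.map(row => row + ("col2" -> (row("col1") + add)))
  }

  val testDf: Frame = (1 to 10).toList.map(i => Map("col" -> i))

  // Each call returns a plain Frame, which is implicitly wrapped again for the next call.
  val res: Frame = testDf.func0(inc, 3).func1(1).func2(2)
}
```

The trade-off is discoverability: the methods only exist where the implicit class is in scope, so readers need to know where it is defined or imported.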