User-Defined Functions - Scala

This topic contains Scala user-defined function (UDF) examples. It shows how to register UDFs, how to invoke them, and which caveats apply to the evaluation order of subexpressions in Spark SQL.

Register a function as a UDF

val squared = (s: Long) => {
  s * s
}
spark.udf.register("square", squared)
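As a sketch of what registration gives you, the example below assumes a local SparkSession (a notebook already provides `spark`). It also relies on `spark.udf.register` returning a `UserDefinedFunction` handle. The same registered function can then be applied to DataFrame columns through that handle or by name inside a SQL expression string. The object and method names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, expr}

// Hypothetical wrapper object; only the Spark calls inside are real API.
object RegisterSquareExample {
  // Squares 1..3 twice: once via the handle returned by register,
  // once by calling the registered name inside a SQL expression.
  def squares(): (Seq[Long], Seq[Long]) = {
    // Assumption: running outside a notebook, so we build a local session.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("register-square")
      .getOrCreate()
    try {
      val squared = (s: Long) => s * s
      // register returns a UserDefinedFunction that can be applied to Columns.
      val squareUdf = spark.udf.register("square", squared)
      val df = spark.range(1, 4).orderBy("id")
      val viaHandle = df.select(squareUdf(col("id"))).collect().map(_.getLong(0)).toSeq
      val viaName = df.select(expr("square(id)")).collect().map(_.getLong(0)).toSeq
      (viaHandle, viaName)
    } finally {
      spark.stop()
    }
  }
}
```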

Call the UDF in Spark SQL

spark.range(1, 20).createOrReplaceTempView("test")
%sql select id, square(id) as id_squared from test
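The `%sql` line is a notebook cell magic. Outside a notebook, a minimal equivalent is to pass the same query to `spark.sql`. The sketch below assumes a local SparkSession and uses `createOrReplaceTempView`, which replaces `registerTempTable` (deprecated since Spark 2.0). The object name `SquareUdfSqlExample` is hypothetical, and an `order by` is added only to make the collected result deterministic.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical wrapper object around the register-then-query steps.
object SquareUdfSqlExample {
  def squares(): Seq[(Long, Long)] = {
    // Assumption: local session; a notebook provides `spark` already.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("square-udf-sql")
      .getOrCreate()
    try {
      val squared = (s: Long) => s * s
      spark.udf.register("square", squared)
      spark.range(1, 20).createOrReplaceTempView("test")
      // Programmatic equivalent of the %sql cell.
      spark.sql("select id, square(id) as id_squared from test order by id")
        .collect()
        .map(row => (row.getLong(0), row.getLong(1)))
        .toSeq
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit =
    squares().foreach { case (id, sq) => println(s"$id -> $sq") }
}
```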

Use UDF with DataFrames

import org.apache.spark.sql.functions.{col, udf}
val squared = udf((s: Long) => s * s)
display(spark.range(1, 20).select(squared(col("id")) as "id_squared"))
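`display` is a notebook helper. In a standalone application you would typically collect or `show()` the result instead. The sketch below assumes a local SparkSession. It applies the `udf`-wrapped function with `withColumn`, which keeps the original `id` next to its square. The object and method names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical wrapper object for the DataFrame variant.
object SquareUdfDataFrameExample {
  def squares(): Seq[(Long, Long)] = {
    // Assumption: local session outside a notebook.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("square-udf-df")
      .getOrCreate()
    try {
      // Wrap a plain Scala function so it can be applied to Columns.
      val squared = udf((s: Long) => s * s)
      spark.range(1, 20)
        .withColumn("id_squared", squared(col("id")))
        .orderBy("id")
        .collect()
        .map(row => (row.getLong(0), row.getLong(1)))
        .toSeq
    } finally {
      spark.stop()
    }
  }
}
```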

Evaluation order and null checking

Spark SQL (including SQL and the DataFrame and Dataset APIs) does not guarantee the order of evaluation of subexpressions. In particular, the inputs of an operator or function are not necessarily evaluated left-to-right or in any other fixed order. For example, logical AND and OR expressions do not have left-to-right “short-circuiting” semantics.

Therefore, it is dangerous to rely on the side effects or evaluation order of Boolean expressions, or on the order of WHERE and HAVING clauses, because the optimizer and planner can reorder such expressions and clauses. In particular, if a UDF relies on SQL short-circuiting for null checking, there is no guarantee that the null check runs before the UDF is invoked. For example:

spark.udf.register("strlen", (s: String) => s.length)
spark.sql("select s from test1 where s is not null and strlen(s) > 1") // no guarantee

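This WHERE clause does not guarantee that the strlen UDF is invoked only after nulls have been filtered out. If strlen is called on a null value, s.length throws a NullPointerException and the query fails.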

To perform proper null checking, we recommend that you do either of the following:

  • Make the UDF null-aware and perform the null check inside the UDF itself.
  • Use IF or CASE WHEN expressions to do the null check, and invoke the UDF only in a conditional branch.

spark.udf.register("strlen_nullsafe", (s: String) => if (s != null) s.length else -1)
spark.sql("select s from test1 where s is not null and strlen_nullsafe(s) > 1") // ok
spark.sql("select s from test1 where if(s is not null, strlen(s), null) > 1")   // ok
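The same two recommendations carry over to the DataFrame API. Below is a minimal sketch under a few assumptions: a local SparkSession, a small hypothetical sample column `s` containing a null, and a null-aware UDF that returns `Option[Int]` (surfaced as a nullable integer column) rather than the `-1` sentinel used above. For the conditional form, `when` without `otherwise` yields null on rows where `s` is null, and `null > 1` is not true, so those rows are dropped without the UDF ever being called on them.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf, when}

// Hypothetical wrapper object; sample data below is illustrative only.
object NullSafeStrlenExample {
  // Returns strings longer than one character, computed two ways.
  def longStrings(): (Seq[String], Seq[String]) = {
    // Assumption: local session outside a notebook.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("null-safe-strlen")
      .getOrCreate()
    import spark.implicits._
    try {
      val df = Seq("a", "abc", null, "hello").toDF("s")

      // 1) Null-aware UDF: null input yields None, i.e. a SQL null.
      val strlenNullSafe = udf((s: String) => Option(s).map(_.length))
      val viaNullAwareUdf = df
        .filter(col("s").isNotNull && strlenNullSafe(col("s")) > 1)
        .as[String].collect().toSeq.sorted

      // 2) Null-unaware UDF invoked only in the branch where s is not null.
      val strlen = udf((s: String) => s.length)
      val viaConditional = df
        .filter(when(col("s").isNotNull, strlen(col("s"))) > 1)
        .as[String].collect().toSeq.sorted

      (viaNullAwareUdf, viaConditional)
    } finally {
      spark.stop()
    }
  }
}
```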