pandas_udf

Creates a pandas user-defined function.

Pandas UDFs are user-defined functions that are executed by Spark, using Arrow to transfer data and pandas to work with the data, which allows vectorized operations. A pandas UDF is defined with pandas_udf, either as a decorator or to wrap a function, and no additional configuration is required. In general, a pandas UDF behaves like a regular PySpark function API.

Syntax

Python
import pyspark.sql.functions as sf

# As a decorator
@sf.pandas_udf(returnType=<returnType>, functionType=<functionType>)
def function_name(col):
    # function body
    pass

# As a function wrapper
function_name = sf.pandas_udf(f=<function>, returnType=<returnType>, functionType=<functionType>)

Parameters

f : function
    Optional. User-defined function. A Python function if used as a standalone function.

returnType : pyspark.sql.types.DataType or str
    Optional. The return type of the user-defined function. The value can be either a DataType object or a DDL-formatted type string.

functionType : int
    Optional. An enum value in PandasUDFType. Default: SCALAR. This parameter exists for compatibility; using Python type hints is encouraged instead.
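
For reference, a minimal sketch of the wrapper form, under the same assumptions as the examples below (an active SparkSession named spark); the function add_one and its logic are illustrative, not part of the API:

Python
import pandas as pd
from pyspark.sql.functions import pandas_udf

# A plain Python function operating on pandas Series.
def add_one(s: pd.Series) -> pd.Series:
    return s + 1

# Wrap the existing function instead of decorating it; the return
# type is given as a DDL-formatted type string.
add_one_udf = pandas_udf(add_one, returnType="long")

spark.range(3).select(add_one_udf("id")).show()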

Examples

Example 1: Series to Series - Convert strings to uppercase.

Python
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()
Output
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+

Example 2: Series to Series with keyword arguments.

Python
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType
from pyspark.sql import functions as sf

@pandas_udf(returnType=IntegerType())
def calc(a: pd.Series, b: pd.Series) -> pd.Series:
    return a + 10 * b

spark.range(2).select(calc(b=sf.col("id") * 10, a=sf.col("id"))).show()
Output
+-----------------------------+
|calc(b => (id * 10), a => id)|
+-----------------------------+
|                            0|
|                          101|
+-----------------------------+

Example 3: Iterator of Series to Iterator of Series.

Python
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for s in iterator:
        yield s + 1

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
df.select(plus_one(df.v)).show()
Output
+-----------+
|plus_one(v)|
+-----------+
|          2|
|          3|
|          4|
+-----------+
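
The iterator form is useful when the UDF needs some expensive one-time setup (for example, loading a model file), because the function is called once per stream of batches, so the setup cost is not paid again for every batch. A minimal sketch, with a hypothetical load_offset() standing in for the costly step:

Python
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_offset(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Hypothetical expensive setup; runs once per stream of batches,
    # not once per batch. load_offset() is illustrative only.
    offset = 100  # e.g. offset = load_offset()
    for s in iterator:
        yield s + offset

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
df.select(plus_offset(df.v)).show()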

Example 4: Series to Scalar - Grouped aggregation.

Python
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(mean_udf(df['v'])).show()
Output
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+

Example 5: Series to Scalar with window functions.

Python
import pandas as pd
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
df.withColumn('mean_v', mean_udf("v").over(w)).show()
Output
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.0|
|  1| 2.0|   1.5|
|  2| 3.0|   3.0|
|  2| 5.0|   4.0|
|  2|10.0|   7.5|
+---+----+------+

Example 6: Iterator of Series to Scalar - Memory-efficient grouped aggregation.

Python
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += v.sum()
        cnt += len(v)
    return sum_val / cnt

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(pandas_mean_iter(df['v'])).show()
Output
+---+-------------------+
| id|pandas_mean_iter(v)|
+---+-------------------+
|  1|                1.5|
|  2|                6.0|
+---+-------------------+
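
Because a pandas UDF behaves like a regular PySpark function, it can also be registered for use in SQL with spark.udf.register. A minimal sketch reusing mean_udf and df from Example 4; the registered name "mean_udf" and the view name "t" are arbitrary choices, not part of the API:

Python
# Register the pandas UDF under a SQL-visible name.
spark.udf.register("mean_udf", mean_udf)

df.createOrReplaceTempView("t")
spark.sql("SELECT id, mean_udf(v) AS mean_v FROM t GROUP BY id").show()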