pandas UDFs Benchmark (Python)
from pyspark.sql.types import *
from pyspark.sql.functions import udf, col, count, rand, collect_list, explode, struct
from pyspark.sql.functions import pandas_udf
# 10M rows with ~1,000 distinct ids (id = original id / 10000, cast to int) and a random column v.
df = spark.range(0, 10 * 1000 * 1000).withColumn('id', (col('id') / 10000).cast('integer')).withColumn('v', rand())
# Cache and materialize the data so the benchmarks below don't re-generate it.
df.cache()
df.count()

df.show()
+---+--------------------+
| id|                   v|
+---+--------------------+
|  0|  0.6326195647822964|
|  0|  0.5705850402990524|
|  0| 0.49334879907662055|
|  0|  0.5635969524407588|
|  0| 0.38477148792102167|
|  0|  0.6361652596893868|
|  0|  0.4025726436821221|
|  0| 0.14631056878370863|
|  0|  0.5334328884655779|
|  0|  0.8870495443933608|
|  0|0.023552906495357684|
|  0| 0.35656629907356774|
|  0|  0.5605900724708613|
|  0|  0.2634029747378819|
|  0|  0.5764435676893086|
|  0|  0.6047131453015696|
|  0|  0.1397939403739349|
|  0|  0.3181283680162754|
|  0|  0.7090868803935365|
|  0|  0.6005245220066551|
+---+--------------------+
only showing top 20 rows
# Row-at-a-time UDF: the Python worker is invoked once per row.
@udf('double')
def plus_one(v):
    return v + 1

%timeit df.withColumn('v', plus_one(df.v)).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
1.84 s ± 125 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
import pandas as pd

# Vectorized (pandas) UDF: invoked once per Arrow batch with a pd.Series.
@pandas_udf("double")
def pandas_plus_one(v: pd.Series) -> pd.Series:
    return v + 1

%timeit df.withColumn('v', pandas_plus_one(df.v)).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
756 ms ± 155 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
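For context, the same arithmetic written as a built-in Column expression never leaves the JVM at all; a baseline sketch for comparison (not part of the original benchmark runs):

# Baseline sketch (not in the original notebook): a native Column expression
# avoids the Python worker and the Arrow transfer entirely.
%timeit df.withColumn('v', df.v + 1).agg(count(col('v'))).show()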
from scipy import stats

# Row-at-a-time UDF calling scipy once per row; the per-call overhead dominates.
@udf('double')
def cdf(v):
    return float(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', cdf(df.v)).agg(count(col('cumulative_probability'))).show()
+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                     10000000|
+-----------------------------+
59.6 s ± 308 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Vectorized version: stats.norm.cdf accepts the whole Series at once.
@pandas_udf('double')
def pandas_cdf(v: pd.Series) -> pd.Series:
    return pd.Series(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', pandas_cdf(df.v)).agg(count(col('cumulative_probability'))).show()
+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                     10000000|
+-----------------------------+
563 ms ± 39.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
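Since Spark 3.0, the same UDF can also be written in the Iterator-of-Series form, which helps when there is expensive one-time setup to amortize across Arrow batches; a minimal sketch (the name pandas_cdf_iter is hypothetical):

from typing import Iterator

@pandas_udf('double')
def pandas_cdf_iter(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Any expensive one-time initialization would go here, before the loop;
    # each yielded Series corresponds to one incoming Arrow batch.
    for v in batches:
        yield pd.Series(stats.norm.cdf(v))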
from pyspark.sql import Row

# Row-at-a-time approach to subtracting the per-group mean: collect each group
# into an array of Rows, rebuild everything in Python, and return a new array.
@udf(ArrayType(df.schema))
def subtract_mean_rows(rows):
    vs = pd.Series([r.v for r in rows])
    vs = vs - vs.mean()
    return [Row(id=rows[i]['id'], v=float(vs[i])) for i in range(len(rows))]
  
%timeit df.groupby('id').agg(collect_list(struct(df['id'], df['v'])).alias('rows')).withColumn('new_rows', subtract_mean_rows(col('rows'))).withColumn('new_row', explode(col('new_rows'))).withColumn('id', col('new_row.id')).withColumn('v', col('new_row.v')).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
19.8 s ± 205 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Grouped-map via applyInPandas: input and output are both a pandas.DataFrame,
# one group per call.
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf.v - pdf.v.mean())

%timeit df.groupby('id').applyInPandas(subtract_mean, schema=df.schema).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
851 ms ± 267 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
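For this particular transformation there is also a pure-JVM alternative via a window function; a sketch worth comparing against (not part of the original benchmark runs):

from pyspark.sql import Window
from pyspark.sql.functions import mean

# Window-function sketch (not in the original notebook): subtract the
# per-group mean with no Python UDF at all.
w = Window.partitionBy('id')
%timeit df.withColumn('v', col('v') - mean('v').over(w)).agg(count(col('v'))).show()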
df2 = df.withColumn('y', rand()).withColumn('x1', rand()).withColumn('x2', rand()).select('id', 'y', 'x1', 'x2')
df2.show()                                                               
+---+-------------------+-------------------+-------------------+
| id|                  y|                 x1|                 x2|
+---+-------------------+-------------------+-------------------+
|  0| 0.2338989825976887|  0.633100356232247|  0.954204207316112|
|  0|0.39643430204857677|   0.85327309907234| 0.9902342558728575|
|  0|0.17447714315818363| 0.8197016619581589| 0.5162565383520107|
|  0|0.26201359386979706| 0.9697244317874232| 0.3846076783577862|
|  0|0.13252729098740157|0.35380200859011635| 0.5394304491052249|
|  0| 0.3953409146437262|0.22169498632444917| 0.2413738208106757|
|  0|0.02847281582088046| 0.6844951181765015|  0.607506212470167|
|  0| 0.6188430279226395| 0.4340005113564953|0.34583147592632846|
|  0|0.16264598956069098| 0.5187656080461458|  0.979809149550738|
|  0|0.11850804931961945|0.06792844257524921| 0.6214232392465295|
|  0| 0.5750950045985798|0.04595390617943351| 0.5410821689447765|
|  0| 0.1591564414070844| 0.9368110503352995| 0.3053825787455162|
|  0|0.09591450306219962| 0.9134240769969283|0.46703661119289264|
|  0| 0.8732691655199151| 0.5099224986589732| 0.2040555895362003|
|  0|0.18420019882623195| 0.2418074843488408|0.42018130234144213|
|  0| 0.9432720398706194| 0.5931652354142246|0.16260194070689293|
|  0|0.15698558582788047|0.31082568486780826| 0.9857279139360818|
|  0| 0.3427373421179152|0.34542445559598867|0.34650686093198035|
|  0|0.21114688992002228| 0.3006537686503924|  0.789695843517491|
|  0|  0.410882892406374|0.49847035154438035|0.45442013837363326|
+---+-------------------+-------------------+-------------------+
only showing top 20 rows
import statsmodels.api as sm

# df2 has four columns: id, y, x1, x2. Since they are independent random draws,
# each per-group fit should produce coefficients near zero.
group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
schema = df2.select(group_column, *x_columns).schema
# Input/output are both a pandas.DataFrame
def ols(pdf: pd.DataFrame) -> pd.DataFrame:
    group_key = pdf[group_column].iloc[0]
    y = pdf[y_column]
    X = pdf[x_columns]
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    # One output row per group: the group key followed by the fitted coefficients.
    return pd.DataFrame([[group_key] + [model.params[i] for i in x_columns]], columns=[group_column] + x_columns)
beta = df2.groupby(group_column).applyInPandas(ols, schema=schema)
beta.show()
+---+--------------------+--------------------+
| id|                  x1|                  x2|
+---+--------------------+--------------------+
|148|0.009689691500611265|-0.01587461704374...|
|463|-0.00980345607239...|3.142477274721590...|
|471|0.003937155730384494|0.002362506756603669|
|496|-0.00921190434451...|-1.55048547909590...|
|833|-5.21659000331312...|0.013393277837858339|
|243|0.005085286958833874|-0.00541790924603...|
|392|-0.01149809820803...|-0.01082029976248...|
|540|-0.02360728612551367|-3.12892925093739...|
|623|-0.01113770785829715|0.021956473747411803|
|737|0.012284202668742914|-0.00832707548783...|
|858| 0.03256189048540876|-0.00868613959592...|
|897|0.003303829387931...|0.005943048530431379|
| 31|-0.00195294404301...|0.005327872686633158|
|516|0.009795829776227867|0.002080688514128...|
| 85|-0.00355796369830...|-1.15381255834101...|
|137|-0.01279334574691...|0.008977364386420276|
|251|0.006849934659350...|0.004550445865878831|
|451|-0.00536587959423...|-0.02222484590602...|
|580| 0.01135045833304452|0.003433770839062957|
|808|  0.0060610609732694|-0.01012065408145...|
+---+--------------------+--------------------+
only showing top 20 rows
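As a quick sanity check, the same fit can be reproduced locally for a single group and compared against beta; a sketch (assumption: a single group fits comfortably in driver memory):

# Sanity-check sketch: fit one group with plain statsmodels on the driver.
pdf0 = df2.filter(col('id') == 0).toPandas()
local_model = sm.OLS(pdf0[y_column], sm.add_constant(pdf0[x_columns])).fit()
print(local_model.params[x_columns])
beta.filter(col('id') == 0).show()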