from pyspark.sql.functions import col, rand

# 10 million rows; id = 0..999, so about 1,000 groups of 10,000 rows each
df = (spark.range(0, 10 * 1000 * 1000)
      .withColumn('id', (col('id') / 10000).cast('integer'))
      .withColumn('v', rand()))
df.cache()
df.count()
df.show()
+---+--------------------+
| id| v|
+---+--------------------+
| 0| 0.6326195647822964|
| 0| 0.5705850402990524|
| 0| 0.49334879907662055|
| 0| 0.5635969524407588|
| 0| 0.38477148792102167|
| 0| 0.6361652596893868|
| 0| 0.4025726436821221|
| 0| 0.14631056878370863|
| 0| 0.5334328884655779|
| 0| 0.8870495443933608|
| 0|0.023552906495357684|
| 0| 0.35656629907356774|
| 0| 0.5605900724708613|
| 0| 0.2634029747378819|
| 0| 0.5764435676893086|
| 0| 0.6047131453015696|
| 0| 0.1397939403739349|
| 0| 0.3181283680162754|
| 0| 0.7090868803935365|
| 0| 0.6005245220066551|
+---+--------------------+
only showing top 20 rows
from pyspark.sql.functions import count, udf

# Row-at-a-time Python UDF: called once per row
@udf('double')
def plus_one(v):
    return v + 1

%timeit df.withColumn('v', plus_one(df.v)).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
1.84 s ± 125 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
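For reference, the same +1 transformation can also be written as a built-in column expression, which never leaves the JVM. The sketch below is not part of the original benchmark; it is only a baseline to contrast with the UDF overhead measured above.

# Illustrative baseline (not in the original run): built-in column arithmetic
# avoids serializing rows to a Python worker entirely.
df.withColumn('v', df.v + 1).agg(count(col('v'))).show()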
import pandas as pd
from pyspark.sql.functions import pandas_udf

# Vectorized (pandas) UDF: called once per Arrow batch with a pandas.Series
@pandas_udf("double")
def pandas_plus_one(v: pd.Series) -> pd.Series:
    return v + 1

%timeit df.withColumn('v', pandas_plus_one(df.v)).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
756 ms ± 155 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
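Pandas UDFs move data between the JVM and the Python worker in Arrow record batches, so batch size is one knob to experiment with. A minimal sketch; the value shown is Spark's documented default and is not tuned for this benchmark.

# Arrow batch size used when shipping data to the Python worker for pandas UDFs.
# 10000 is the default; shown here only for illustration.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")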
from scipy import stats

# Row-at-a-time Python UDF: scipy is invoked once per row
@udf('double')
def cdf(v):
    return float(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', cdf(df.v)).agg(count(col('cumulative_probability'))).show()
+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
| 10000000|
+-----------------------------+
59.6 s ± 308 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Vectorized (pandas) UDF: scipy is invoked once per batch instead of once per row
@pandas_udf('double')
def pandas_cdf(v: pd.Series) -> pd.Series:
    return pd.Series(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', pandas_cdf(df.v)).agg(count(col('cumulative_probability'))).show()
+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
| 10000000|
+-----------------------------+
563 ms ± 39.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
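The same pandas UDF can also be registered for use from SQL if that fits the workflow better. A minimal sketch, not part of the timed runs; the view name 'tbl' is only illustrative.

# Sketch: register the vectorized UDF and call it from SQL
spark.udf.register('pandas_cdf', pandas_cdf)
df.createOrReplaceTempView('tbl')
spark.sql('SELECT count(pandas_cdf(v)) FROM tbl').show()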
from pyspark.sql import Row
from pyspark.sql.types import ArrayType
from pyspark.sql.functions import collect_list, explode, struct

# Grouped mean subtraction without applyInPandas: collect each group into an
# array of rows, subtract the group mean inside a Python UDF, then explode back
@udf(ArrayType(df.schema))
def substract_mean(rows):
    vs = pd.Series([r.v for r in rows])
    vs = vs - vs.mean()
    return [Row(id=rows[i]['id'], v=float(vs[i])) for i in range(len(rows))]

%timeit df.groupby('id').agg(collect_list(struct(df['id'], df['v'])).alias('rows')).withColumn('new_rows', substract_mean(col('rows'))).withColumn('new_row', explode(col('new_rows'))).withColumn('id', col('new_row.id')).withColumn('v', col('new_row.v')).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
19.8 s ± 205 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Input/output are both a pandas.DataFrame
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf.v - pdf.v.mean())

%timeit df.groupby('id').applyInPandas(subtract_mean, schema=df.schema).agg(count(col('v'))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
851 ms ± 267 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
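A quick sanity check of the grouped transform: after subtract_mean, the average of v within each group should be approximately zero. A minimal sketch, separate from the timed run:

from pyspark.sql.functions import avg

# Each group's mean of v should be ~0 once the group mean has been subtracted
df.groupby('id').applyInPandas(subtract_mean, schema=df.schema) \
  .groupby('id').agg(avg('v').alias('mean_v')).show(5)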
df2 = (df.withColumn('y', rand())
       .withColumn('x1', rand())
       .withColumn('x2', rand())
       .select('id', 'y', 'x1', 'x2'))
df2.show()
+---+-------------------+-------------------+-------------------+
| id| y| x1| x2|
+---+-------------------+-------------------+-------------------+
| 0| 0.2338989825976887| 0.633100356232247| 0.954204207316112|
| 0|0.39643430204857677| 0.85327309907234| 0.9902342558728575|
| 0|0.17447714315818363| 0.8197016619581589| 0.5162565383520107|
| 0|0.26201359386979706| 0.9697244317874232| 0.3846076783577862|
| 0|0.13252729098740157|0.35380200859011635| 0.5394304491052249|
| 0| 0.3953409146437262|0.22169498632444917| 0.2413738208106757|
| 0|0.02847281582088046| 0.6844951181765015| 0.607506212470167|
| 0| 0.6188430279226395| 0.4340005113564953|0.34583147592632846|
| 0|0.16264598956069098| 0.5187656080461458| 0.979809149550738|
| 0|0.11850804931961945|0.06792844257524921| 0.6214232392465295|
| 0| 0.5750950045985798|0.04595390617943351| 0.5410821689447765|
| 0| 0.1591564414070844| 0.9368110503352995| 0.3053825787455162|
| 0|0.09591450306219962| 0.9134240769969283|0.46703661119289264|
| 0| 0.8732691655199151| 0.5099224986589732| 0.2040555895362003|
| 0|0.18420019882623195| 0.2418074843488408|0.42018130234144213|
| 0| 0.9432720398706194| 0.5931652354142246|0.16260194070689293|
| 0|0.15698558582788047|0.31082568486780826| 0.9857279139360818|
| 0| 0.3427373421179152|0.34542445559598867|0.34650686093198035|
| 0|0.21114688992002228| 0.3006537686503924| 0.789695843517491|
| 0| 0.410882892406374|0.49847035154438035|0.45442013837363326|
+---+-------------------+-------------------+-------------------+
only showing top 20 rows
import statsmodels.api as sm

# df2 has four columns: id, y, x1, x2
group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
schema = df2.select(group_column, *x_columns).schema

# Input/output are both a pandas.DataFrame
def ols(pdf: pd.DataFrame) -> pd.DataFrame:
    group_key = pdf[group_column].iloc[0]
    y = pdf[y_column]
    X = pdf[x_columns]
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    return pd.DataFrame([[group_key] + [model.params[i] for i in x_columns]],
                        columns=[group_column] + x_columns)

# Fit one OLS model per id group; the result has one row of coefficients per group
beta = df2.groupby(group_column).applyInPandas(ols, schema=schema)
beta.show()
+---+--------------------+--------------------+
| id| x1| x2|
+---+--------------------+--------------------+
|148|0.009689691500611265|-0.01587461704374...|
|463|-0.00980345607239...|3.142477274721590...|
|471|0.003937155730384494|0.002362506756603669|
|496|-0.00921190434451...|-1.55048547909590...|
|833|-5.21659000331312...|0.013393277837858339|
|243|0.005085286958833874|-0.00541790924603...|
|392|-0.01149809820803...|-0.01082029976248...|
|540|-0.02360728612551367|-3.12892925093739...|
|623|-0.01113770785829715|0.021956473747411803|
|737|0.012284202668742914|-0.00832707548783...|
|858| 0.03256189048540876|-0.00868613959592...|
|897|0.003303829387931...|0.005943048530431379|
| 31|-0.00195294404301...|0.005327872686633158|
|516|0.009795829776227867|0.002080688514128...|
| 85|-0.00355796369830...|-1.15381255834101...|
|137|-0.01279334574691...|0.008977364386420276|
|251|0.006849934659350...|0.004550445865878831|
|451|-0.00536587959423...|-0.02222484590602...|
|580| 0.01135045833304452|0.003433770839062957|
|808| 0.0060610609732694|-0.01012065408145...|
+---+--------------------+--------------------+
only showing top 20 rows
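Because y, x1, and x2 are independent uniform random columns, the true per-group coefficients are zero, which is why the fitted values hover near zero. The per-group result is tiny (one row per id), so it can be pulled to the driver for further inspection; a sketch:

# Collect the small per-group coefficient table locally (illustrative follow-up)
beta_pdf = beta.orderBy('id').toPandas()
print(beta_pdf.head())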