%md # TWS Python Top K
%sh pip install dbldatagen
Collecting dbldatagen
Using cached dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
Installing collected packages: dbldatagen
Successfully installed dbldatagen-0.4.0.post1
import dbldatagen as dg
import time
import datetime
from pyspark.sql.types import IntegerType, FloatType, StringType, TimestampType

# Set table name
table_name = f"synthetic_data_{int(time.time())}"
print(f"table_name: {table_name}")

# Generate synthetic data with user and value columns
data_rows = 100
df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
           .withColumn("user", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'], random=True)
           .withColumn("value", IntegerType(), minValue=0, maxValue=200)
           )
df = df_spec.build()

# Write to Delta table
df.write.format("delta").mode("overwrite").saveAsTable(table_name)
table_name: synthetic_data_1739563009
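%md (Optional) Quick sanity check on the generated table — a minimal sketch, not part of the original notebook; it assumes `table_name` from the cell above is still in scope.
# Optional sanity check: row count and per-user distribution of the synthetic data
generated = spark.table(table_name)
print(generated.count())                  # expect 100 rows
generated.groupBy("user").count().show()  # rough per-user row counts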
spark.conf.set(
    "spark.sql.streaming.stateStore.providerClass",
    "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
)

from pyspark.sql.functions import col

print(table_name)
df = spark.table(table_name)
display(df)
synthetic_data_1739563009
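%md transformWithStateInPandas relies on the RocksDB state store provider configured above. As an optional check (not in the original notebook), read the config back to confirm it took effect.
# Optional: confirm which state store provider the streaming query will use
print(spark.conf.get("spark.sql.streaming.stateStore.providerClass"))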
%md Define output schema
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

output_schema = StructType([
    StructField("rank", IntegerType(), False),
    StructField("user", StringType(), False),
    StructField("value", IntegerType(), False)
])
%md Define stateful processor
from typing import Iterator, Tuple
import heapq
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, IntegerType, StringType


class TopKStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        """
        Initializes the state handle for top-k elements.
        """
        state_schema = StructType([
            StructField("user", StringType(), False),
            StructField("value", IntegerType(), False)
        ])
        self.top_k_state = handle.getListState("topKState", state_schema)
        self.k = 5

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        """
        Processes input rows, maintaining the top-k elements.
        """
        # Priority queue to store the top-k elements
        pq = []

        # Process incoming batch
        for pdf in rows:
            for _, row in pdf.iterrows():
                name, value = row["user"], row["value"]
                if len(pq) < self.k:
                    heapq.heappush(pq, (value, name))  # Store as (value, name) for ordering
                else:
                    if value > pq[0][0]:  # Compare with the smallest in the heap
                        heapq.heappushpop(pq, (value, name))

        # Retrieve previous state
        exists = self.top_k_state.exists()
        old_top_k = list(self.top_k_state.get()) if exists else []
        old_top_k.sort(key=lambda x: -x[1])  # Sort in descending order

        # Merge current and previous top-k
        new_top_k = []
        pq = sorted(pq, reverse=True, key=lambda x: x[0])  # Sort by value in descending order
        i, j = 0, 0
        while len(new_top_k) < self.k and (i < len(pq) or j < len(old_top_k)):
            if i < len(pq) and (j >= len(old_top_k) or pq[i][0] >= old_top_k[j][1]):
                new_top_k.append((pq[i][1], pq[i][0]))  # (name, value)
                i += 1
            else:
                new_top_k.append((old_top_k[j][0], old_top_k[j][1]))  # (name, value)
                j += 1

        # Update state
        self.top_k_state.put(new_top_k)

        # Emit only changed top-k elements
        output_rows = []
        for rank, (name, value) in enumerate(new_top_k, start=1):
            if len(old_top_k) <= rank - 1 or old_top_k[rank - 1] != (name, value):
                output_rows.append((rank, name, value))

        yield pd.DataFrame(output_rows, columns=["rank", "user", "value"])

    def close(self) -> None:
        pass
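%md The processor keeps a size-k min-heap keyed by value, so the smallest of the current top k sits at the heap root and is evicted whenever a larger value arrives. The sketch below (local Python only, hypothetical sample data, not part of the original notebook) exercises the same heapq pattern outside Spark.
# Local illustration of the heap logic used in handleInputRows (no Spark required)
import heapq

k = 5
sample = [("user1", 42), ("user2", 180), ("user3", 7), ("user4", 199),
          ("user5", 63), ("user1", 150), ("user2", 91), ("user3", 175)]

pq = []
for name, value in sample:
    if len(pq) < k:
        heapq.heappush(pq, (value, name))   # heap ordered by value
    elif value > pq[0][0]:                  # new value beats the current minimum
        heapq.heappushpop(pq, (value, name))

for rank, (value, name) in enumerate(sorted(pq, reverse=True), start=1):
    print(rank, name, value)                # top-k in descending order of value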
import uuid

base_path = f"/Workspace/Users/bo.gao@databricks.com/tws/{uuid.uuid4()}"
checkpoint_dir = base_path + "/checkpoint"
output_path = base_path + "/output"
from pyspark.sql.functions import col

q = (
    spark.readStream
    .format("delta")
    .option("maxFilesPerTrigger", "1")
    .table(table_name)
    .groupBy("user")
    .transformWithStateInPandas(
        statefulProcessor=TopKStatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream
    .format("console")
    .option("checkpointLocation", checkpoint_dir)
    .outputMode("update")
    .trigger(availableNow=True)
    .start()
)
2b2017a6-4fef-4c3b-9258-5a555312982a
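%md Because the query uses trigger(availableNow=True), it processes the available backlog and stops on its own. An optional follow-up (not in the original notebook) is to block until it finishes and inspect the final progress report.
# Optional: wait for the availableNow run to complete, then inspect the last micro-batch metrics
q.awaitTermination()
print(q.lastProgress)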
%md (Optional) Drop the test table and delete the output/checkpoint paths
spark.sql(f"DROP TABLE IF EXISTS {table_name}")
DataFrame[]
dbutils.fs.rm(base_path, True)
True