    TWS Python Top K

    %sh
    pip install dbldatagen
    Collecting dbldatagen
      Using cached dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
    Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
    Installing collected packages: dbldatagen
    Successfully installed dbldatagen-0.4.0.post1
    [notice] A new release of pip is available: 24.0 -> 25.0.1
    [notice] To update, run: pip install --upgrade pip
    import dbldatagen as dg
    import time
    from pyspark.sql.types import IntegerType, StringType
    
    # Set table name
    table_name = f"synthetic_data_{int(time.time())}"
    print(f"table_name: {table_name}")
    
    # Generate synthetic data with a user column and a numeric value column
    data_rows = 100
    
    df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
               .withColumn("user", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'], random=True)
               .withColumn("value", IntegerType(), minValue=0, maxValue=200)
    )
    
                                
    df = df_spec.build()
    
    # Write to a Delta table
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)
    
    # Use the RocksDB state store provider, which transformWithStateInPandas requires
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    print(table_name)
    df = spark.table(table_name)
    display(df)
    synthetic_data_1739563009
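
    As a quick sanity check, a plain batch aggregation over the generated table shows each user's maximum value, which is roughly what the streaming top-K processor should converge to once it has seen all rows. The cell below is only an illustrative sketch using standard DataFrame APIs; the preview name is arbitrary.

    # Optional sanity check: per-user maximum value in the generated data
    from pyspark.sql import functions as F
    
    preview = (spark.table(table_name)
               .groupBy("user")
               .agg(F.max("value").alias("max_value"))
               .orderBy(F.desc("max_value")))
    display(preview)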

    Define output schema

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
    
    output_schema = StructType([
        StructField("rank", IntegerType(), False),
        StructField("user", StringType(), False),
        StructField("value", IntegerType(), False)
    ])
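
    Each pandas DataFrame emitted by the stateful processor must match this schema, i.e. one (rank, user, value) row per changed position. As an illustration only (made-up values), a single emitted batch could look like:

    import pandas as pd
    
    # Illustrative example of one emitted batch matching output_schema
    example_output = pd.DataFrame(
        [(1, "user3", 198), (2, "user1", 185)],
        columns=["rank", "user", "value"],
    )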
    
    

    Define stateful processor

    from typing import Iterator
    import heapq
    import pandas as pd
    from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
    from pyspark.sql.types import StructType, StructField, IntegerType, StringType
    
    class TopKStatefulProcessor(StatefulProcessor):
    
        def init(self, handle: StatefulProcessorHandle):
            """
            Initializes the state handle for top-k elements.
            """
            state_schema = StructType([
                StructField("user", StringType(), False),
                StructField("value", IntegerType(), False)
            ])
            self.top_k_state = handle.getListState("topKState", state_schema)
            self.k = 5
    
        def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
            """
            Processes input rows, maintaining the top-k elements.
            """
            # Bounded min-heap holding this batch's top-k elements as (value, user)
            pq = []
    
            # Process incoming batch
            for pdf in rows:
                for _, row in pdf.iterrows():
                    name, value = row["user"], row["value"]
                    if len(pq) < self.k:
                        heapq.heappush(pq, (value, name))  # Store as (value, name) for ordering
                    else:
                        if value > pq[0][0]:  # Compare with the smallest in the heap
                            heapq.heappushpop(pq, (value, name))
    
            # Retrieve previous state
            exists = self.top_k_state.exists()
            old_top_k = list(self.top_k_state.get()) if exists else []
            old_top_k.sort(key=lambda x: -x[1])  # Sort in descending order
    
            # Merge current and previous top-k
            new_top_k = []
            pq = sorted(pq, reverse=True, key=lambda x: x[0])  # Sort by value in descending order
            i, j = 0, 0
    
            while len(new_top_k) < self.k and (i < len(pq) or j < len(old_top_k)):
                if i < len(pq) and (j >= len(old_top_k) or pq[i][0] >= old_top_k[j][1]):
                    new_top_k.append((pq[i][1], pq[i][0]))  # (name, value)
                    i += 1
                else:
                    new_top_k.append((old_top_k[j][0], old_top_k[j][1]))  # (name, value)
                    j += 1
    
            # Update state
            self.top_k_state.put(new_top_k)
    
            # Emit only changed top-k elements
            output_rows = []
            for rank, (name, value) in enumerate(new_top_k, start=1):
                if len(old_top_k) <= rank - 1 or old_top_k[rank - 1] != (name, value):
                    output_rows.append((rank, name, value))
            yield pd.DataFrame(output_rows, columns=["rank", "user", "value"])
        
        def close(self) -> None:
            pass
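
    The bounded min-heap bookkeeping used in handleInputRows can be sanity-checked in plain Python, independent of Spark and of the state handle. The snippet below is a minimal sketch of the same pattern on made-up data:

    import heapq
    
    # Keep only the k largest (value, user) pairs using a bounded min-heap
    k = 3
    pq = []
    for name, value in [("user1", 10), ("user2", 42), ("user3", 7),
                        ("user4", 99), ("user5", 23)]:
        if len(pq) < k:
            heapq.heappush(pq, (value, name))
        elif value > pq[0][0]:
            heapq.heappushpop(pq, (value, name))
    
    # Largest values first: [(99, 'user4'), (42, 'user2'), (23, 'user5')]
    print(sorted(pq, reverse=True))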
    import uuid
    
    base_path = f"/Workspace/Users/bo.gao@databricks.com/tws/{uuid.uuid4()}"
    checkpoint_dir = base_path + "/checkpoint"
    output_path = base_path + "/output"
    from pyspark.sql.functions import col
    
    q = spark \
        .readStream \
        .format("delta") \
        .option("maxFilesPerTrigger", "1") \
        .table(table_name) \
        .groupBy("user") \
        .transformWithStateInPandas( \
            statefulProcessor=TopKStatefulProcessor(), \
            outputStructType=output_schema, \
            outputMode="Update", \
            timeMode="None", \
        ) \
        .writeStream \
        .format("console") \
        .option("checkpointLocation", checkpoint_dir) \
        .outputMode("update") \
        .trigger(availableNow=True) \
        .start()
    2b2017a6-4fef-4c3b-9258-5a555312982a
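
    Because the query runs with an availableNow trigger, it processes all files currently in the table and then stops on its own. One way to block until the run completes and then inspect its progress, using the standard StreamingQuery API:

    # Wait for the availableNow run to finish, then inspect its progress
    q.awaitTermination()
    print(q.status)
    print(q.lastProgress)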

    (Optional) Drop the test table and delete the output/checkpoint path

    spark.sql(f"DROP TABLE IF EXISTS {table_name}")
    DataFrame[]
    dbutils.fs.rm(base_path, True)
    True
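
    If the query were instead running with a continuous trigger, any still-active streams should be stopped before the checkpoint directory is deleted. A minimal sketch:

    # Stop any still-active streaming queries before removing their checkpoints
    for active_query in spark.streams.active:
        active_query.stop()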