    TWS session tracking in Python

    %sh
    pip install dbldatagen
    Collecting dbldatagen
      Using cached dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
      Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
    Installing collected packages: dbldatagen
    Successfully installed dbldatagen-0.4.0.post1
    # transformWithState requires the RocksDB state store provider; set it
    # before starting any streaming query
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
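
    Reading the config back is a quick sanity check that the provider took effect (a minimal sketch; spark.conf.get simply echoes the value set above):

    # Sanity check: the provider class should now point at RocksDB
    print(spark.conf.get("spark.sql.streaming.stateStore.providerClass"))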
    import dbldatagen as dg
    from pyspark.sql.types import *
    from datetime import datetime, timedelta
    import pandas as pd
    from pyspark.sql import Row
    from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
    from typing import Iterator
    import uuid
    import time
    
    # Generate synthetic session data
    def generate_test_data(spark):
        data_rows = 1000 * 100
        base_time = datetime.now()
        
        df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
            .withColumn("user_id", StringType(), 
                values=['user1', 'user2', 'user3', 'user4', 'user5'])
            .withColumn("action_type", StringType(), 
                values=['login', 'page_view', 'purchase'],
                random=True, weights=[1, 5, 1])
            .withColumn("timestamp", TimestampType(),
                begin=base_time - timedelta(hours=1),
                end=base_time,
                random=True)
            .withColumn("session_value", IntegerType(), 
                minValue=1, maxValue=100)
        )
        
        return df_spec.build()
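
    To sanity-check the generator before wiring it into a stream, you can build the spec and peek at a few rows (a minimal sketch that assumes the notebook's ambient spark session):

    # Preview a handful of synthetic session events
    generate_test_data(spark).show(5, truncate=False)
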
    # Define Session Tracking Processor
    class SessionTrackingProcessor(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            # Define schema for session state
            state_schema = StructType([
                StructField("last_timestamp", LongType(), True),
                StructField("login_count", IntegerType(), True),
                StructField("page_view_count", IntegerType(), True),
                StructField("purchase_count", IntegerType(), True),
                StructField("session_value", LongType(), True)
            ])
            
            self.session_state = handle.getValueState("session_state", state_schema)
            self.handle = handle
            # Session timeout in milliseconds (30 minutes)
            self.session_timeout_ms = 30 * 60 * 1000
            
        def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
            # Get current state or initialize if not exists
            if self.session_state.exists():
                current_state = self.session_state.get()
                state_dict = {
                    'last_timestamp': current_state[0],
                    'login_count': current_state[1],
                    'page_view_count': current_state[2],
                    'purchase_count': current_state[3],
                    'session_value': current_state[4]
                }
            else:
                state_dict = {
                    'last_timestamp': 0,
                    'login_count': 0,
                    'page_view_count': 0,
                    'purchase_count': 0,
                    'session_value': 0
                }
                
            results = []
            for pdf in rows:
                # Process each row in the batch
                for _, row in pdf.iterrows():
                    timestamp = int(row['timestamp'].timestamp() * 1000)  # Convert to milliseconds
                    action_type = row['action_type']
                    value = row['session_value']
                    
                    # Update counts based on action type
                    if action_type == 'login':
                        state_dict['login_count'] += 1
                    elif action_type == 'page_view':
                        state_dict['page_view_count'] += 1
                    elif action_type == 'purchase':
                        state_dict['purchase_count'] += 1
                        
                    # Delete the timer registered for the previous event, if any.
                    # This must use the old last_timestamp, so it has to happen
                    # before the state is overwritten with the current timestamp.
                    if state_dict['last_timestamp'] > 0:
                        old_timer = state_dict['last_timestamp'] + self.session_timeout_ms
                        self.handle.deleteTimer(old_timer)

                    state_dict['session_value'] += value
                    state_dict['last_timestamp'] = timestamp

                    # Register a fresh timer for this event's session timeout
                    new_timer = timestamp + self.session_timeout_ms
                    self.handle.registerTimer(new_timer)
                    
                    # Create result row
                    results.append({
                        'user_id': key[0],
                        'status': 'ACTIVE',
                        'session_value': state_dict['session_value'],
                        'login_count': state_dict['login_count'],
                        'page_view_count': state_dict['page_view_count'],
                        'purchase_count': state_dict['purchase_count']
                    })
            
            # Update state
            self.session_state.update((
                state_dict['last_timestamp'],
                state_dict['login_count'],
                state_dict['page_view_count'],
                state_dict['purchase_count'],
                state_dict['session_value']
            ))
            
            return iter([pd.DataFrame(results)])
        
        # Called when a registered timer fires: the session has been idle past
        # the 30-minute timeout, so emit a final summary and drop its state
        def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
            # Get final state before clearing
            if self.session_state.exists():
                final_state = self.session_state.get()
                result = pd.DataFrame([{
                    'user_id': key[0],
                    'status': 'EXPIRED',
                    'session_value': final_state[4],
                    'login_count': final_state[1],
                    'page_view_count': final_state[2],
                    'purchase_count': final_state[3]
                }])
            else:
                result = pd.DataFrame([{
                    'user_id': key[0],
                    'status': 'EXPIRED',
                    'session_value': 0,
                    'login_count': 0,
                    'page_view_count': 0,
                    'purchase_count': 0
                }])
                
            # Clear session state
            self.session_state.clear()
            return iter([result])
            
        def close(self) -> None:
            pass
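
    Because handleInputRows is plain pandas logic, it can be smoke-tested off-cluster with stand-in state objects. The stubs below are hypothetical helpers written for this sketch, not part of the PySpark API; only SessionTrackingProcessor itself comes from the cell above:

    # Hypothetical stand-ins for ValueState / StatefulProcessorHandle (local testing only)
    class StubValueState:
        def __init__(self):
            self._value = None
        def exists(self):
            return self._value is not None
        def get(self):
            return self._value
        def update(self, value):
            self._value = value
        def clear(self):
            self._value = None

    class StubHandle:
        def getValueState(self, name, schema):
            return StubValueState()
        def registerTimer(self, ts):
            pass
        def deleteTimer(self, ts):
            pass

    probe = SessionTrackingProcessor()
    probe.init(StubHandle())
    events = pd.DataFrame({
        'timestamp': pd.to_datetime(['2024-01-01 12:00:00', '2024-01-01 12:05:00']),
        'action_type': ['login', 'purchase'],
        'session_value': [10, 25],
    })
    for out in probe.handleInputRows(('user1',), iter([events]), None):
        print(out)  # one DataFrame with two ACTIVE snapshot rows (session_value 10, then 35)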
    
    # Define output schema
    output_schema = StructType([
        StructField("user_id", StringType(), True),
        StructField("status", StringType(), True),
        StructField("session_value", LongType(), True),
        StructField("login_count", IntegerType(), True),
        StructField("page_view_count", IntegerType(), True),
        StructField("purchase_count", IntegerType(), True)
    ])
    
    
    # Generate and save test data
    table_name = f"session_data_{int(time.time())}"
    df = generate_test_data(spark)
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)
    
    # Set up paths
    checkpoint_dir = f"/tmp/session_checkpoint_{uuid.uuid4()}"
    output_path = f"/tmp/session_output_{uuid.uuid4()}"
    
    # Create and start query
    query = (spark.readStream
        .format("delta")
        .table(table_name)
        .groupBy("user_id")
        .transformWithStateInPandas(
            statefulProcessor=SessionTrackingProcessor(),
            outputStructType=output_schema,
            outputMode="Append",
            timeMode="EventTime"
        )
        .writeStream
        .format("delta")
        .outputMode("append")
        .option("checkpointLocation", checkpoint_dir)
        .start(output_path)
    )
    
    # Wait for processing to complete
    query.processAllAvailable()
    query.stop()
    spark.read.format("delta").load(output_path).show()
    +-------+------+-------------+-----------+---------------+--------------+
    |user_id|status|session_value|login_count|page_view_count|purchase_count|
    +-------+------+-------------+-----------+---------------+--------------+
    |  user1|ACTIVE|            1|          0|              0|             1|
    |  user1|ACTIVE|            7|          0|              1|             1|
    |  user1|ACTIVE|           18|          0|              1|             2|
    |  user1|ACTIVE|           34|          1|              1|             2|
    |  user1|ACTIVE|           55|          1|              2|             2|
    |  user1|ACTIVE|           81|          1|              3|             2|
    |  user1|ACTIVE|          112|          1|              4|             2|
    |  user1|ACTIVE|          148|          1|              5|             2|
    |  user1|ACTIVE|          189|          1|              6|             2|
    |  user1|ACTIVE|          235|          1|              6|             3|
    |  user1|ACTIVE|          286|          2|              6|             3|
    |  user1|ACTIVE|          342|          2|              7|             3|
    |  user1|ACTIVE|          403|          2|              8|             3|
    |  user1|ACTIVE|          469|          2|              9|             3|
    |  user1|ACTIVE|          540|          2|             10|             3|
    |  user1|ACTIVE|          616|          2|             11|             3|
    |  user1|ACTIVE|          697|          2|             12|             3|
    |  user1|ACTIVE|          783|          2|             13|             3|
    |  user1|ACTIVE|          874|          3|             13|             3|
    |  user1|ACTIVE|          970|          3|             14|             3|
    +-------+------+-------------+-----------+---------------+--------------+
    only showing top 20 rows
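
    Each incoming event emits a running ACTIVE snapshot for its session, which is why the cumulative counts climb row by row above; a session's closing EXPIRED row appears only once its 30-minute timer fires. To read back just those closing rows, if any have been emitted (a minimal sketch over the same output path):

    # Keep only the timer-emitted session summaries
    (spark.read.format("delta")
        .load(output_path)
        .filter("status = 'EXPIRED'")
        .show())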