    TWS session tracking in Scala

    Create synthetic data using dbldatagen in Python

    %sh
    pip install dbldatagen
Collecting dbldatagen
  Using cached dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
Installing collected packages: dbldatagen
Successfully installed dbldatagen-0.4.0.post1
    %python
    import dbldatagen as dg
    from pyspark.sql.types import IntegerType, StringType, TimestampType
    from datetime import datetime, timedelta
    
    # Set table name
    table_name = dbutils.widgets.get("table_name")
    
    # Generate session data with user_id, action_type, and timestamp
    data_rows = 1000 * 100
    base_time = datetime.now()
    
    df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
               .withColumn("user_id", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'])
               .withColumn("action_type", StringType(), values=['login', 'page_view', 'purchase'], 
                          random=True, weights=[1, 5, 1])
               .withColumn("timestamp", TimestampType(), 
                          begin=base_time - timedelta(hours=1),
                          end=base_time,
                          random=True)
               .withColumn("session_value", IntegerType(), minValue=1, maxValue=100))
    
    df = df_spec.build()
    
    # Write to Delta table
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)

    Set StateStoreProvider and input table name

    // Scala code
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    
    val tableName = dbutils.widgets.get("table_name")
    // Use spark.table() instead of read.format("delta").load()
    val df = spark.table(tableName)
tableName: String = session_tracking_input
df: org.apache.spark.sql.DataFrame = [user_id: string, action_type: string ... 2 more fields]
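As a quick sanity check (a sketch, not part of the original notebook), you can confirm the schema and the mix of action types before wiring up the stream:

// Optional sanity check: inspect the schema and the action_type distribution
df.printSchema()
df.groupBy("action_type").count().show()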

    Define stateful structs that our processor will use

    package org.apache.spark.sql.streaming
    
    // Case class to hold session state
    object SessionState {
      case class SessionState(
        lastTimestamp: Long,
        loginCount: Int,
        pageViewCount: Int,
        purchaseCount: Int,
        sessionValue: Long
      )
    }
    
    
    Warning: classes defined within packages cannot be redefined without a cluster restart. Compilation successful.

Import our structs and the necessary libraries

    // Import the RocksDB state store provider
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import java.util.UUID
    import org.apache.spark.sql.streaming.StatefulProcessor
    import org.apache.spark.sql.streaming._
    
    import java.sql.Timestamp
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.SessionState.SessionState

    Define our StatefulProcessor

    class SessionTrackingProcessor
      extends StatefulProcessor[String, (String, String, Long, Int), (String, String, Long, Int, Int, Int)] {
      
      @transient protected var _sessionState: ValueState[SessionState] = _
      
      // Session timeout in milliseconds (30 minutes)
      val sessionTimeoutMs: Long = 30 * 60 * 1000
      
      override def init(
          outputMode: OutputMode,
          timeMode: TimeMode): Unit = {
        _sessionState = getHandle.getValueState[SessionState]("sessionState",
          Encoders.product[SessionState], TTLConfig.NONE)
      }
      
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String, Long, Int)],
      timerValues: TimerValues): Iterator[(String, String, Long, Int, Int, Int)] = {
    
    // Start from the persisted session state (or a fresh one for a new session)
    // and carry it forward so counters accumulate across all rows in the batch
    var currentState = Option(_sessionState.get()).getOrElse(
      SessionState(0L, 0, 0, 0, 0L)
    )
    
    val results = inputRows.map { case (userId, actionType, timestamp, value) =>
      // Clear the previously registered timer, if any, before extending the session
      if (currentState.lastTimestamp > 0) {
        getHandle.deleteTimer(currentState.lastTimestamp + sessionTimeoutMs)
      }
      
      // Update counters based on action type
      currentState = actionType match {
        case "login" => currentState.copy(
          lastTimestamp = timestamp,
          loginCount = currentState.loginCount + 1,
          sessionValue = currentState.sessionValue + value
        )
        case "page_view" => currentState.copy(
          lastTimestamp = timestamp,
          pageViewCount = currentState.pageViewCount + 1,
          sessionValue = currentState.sessionValue + value
        )
        case "purchase" => currentState.copy(
          lastTimestamp = timestamp,
          purchaseCount = currentState.purchaseCount + 1,
          sessionValue = currentState.sessionValue + value
        )
        case _ => currentState.copy(lastTimestamp = timestamp)
      }
      
      // Register a new timer for the extended session window
      getHandle.registerTimer(timestamp + sessionTimeoutMs)
      
      // Persist the updated state
      _sessionState.update(currentState)
      
      // Return current session status
      (userId, "ACTIVE", currentState.sessionValue, 
       currentState.loginCount, currentState.pageViewCount, currentState.purchaseCount)
    }.toList
    
    results.iterator
  }
      
      override def handleExpiredTimer(
          key: String,
          timerValues: TimerValues,
          expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Long, Int, Int, Int)] = {
        
        // Get final state before clearing
        val finalState = Option(_sessionState.get()).getOrElse(
          SessionState(0L, 0, 0, 0, 0L)
        )
        
        // Clear session state
        _sessionState.clear()
        
        // Return final session summary
        Iterator((key, "EXPIRED", finalState.sessionValue, 
                 finalState.loginCount, finalState.pageViewCount, finalState.purchaseCount))
      }
    }
    defined class SessionTrackingProcessor
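The processor above keeps state until its timer fires and then clears it explicitly. As an alternative sketch (not what this notebook does), the value state could be registered with a TTL so the state store drops entries that have not been updated for a while; the 30-minute duration and the TTLConfig(Duration.ofMinutes(30)) call are assumptions to illustrate the idea:

// Hypothetical variant of init(): register the value state with a TTL instead of TTLConfig.NONE
import java.time.Duration

_sessionState = getHandle.getValueState[SessionState]("sessionState",
  Encoders.product[SessionState], TTLConfig(Duration.ofMinutes(30)))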

    Define our input stream

    // Read from Delta table
    val inputStream = spark.readStream
      .format("delta")
      .table(tableName)
      .as[(String, String, Long, Int)]
      .map(row => (row._1, row._2, row._3, row._4))
    inputStream: org.apache.spark.sql.Dataset[(String, String, Long, Int)] = [_1: string, _2: string ... 2 more fields]
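The tuple encoding above relies on Spark casting the timestamp column down to a Long. If you prefer to make that conversion explicit, an equivalent way to build the typed input might look like the following sketch (the column names are the ones produced by the generator earlier; event_time is a name introduced here):

// A sketch (not part of the original notebook): select columns explicitly and
// cast the event time to epoch seconds before encoding as a tuple
val typedInput = spark.readStream
  .table(tableName)
  .selectExpr(
    "user_id",
    "action_type",
    "CAST(timestamp AS LONG) AS event_time",
    "session_value")
  .as[(String, String, Long, Int)]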

    Define output table and checkpoint location

    val checkpointLocation = "/tmp/streaming_query/checkpoint_" + UUID.randomUUID().toString
    val outputTable = "/tmp/streaming_query/output_table_" + UUID.randomUUID().toString
checkpointLocation: String = /Workspace/Users/eric.marnadi@databricks.com/streaming_query/checkpoint_7e59eeb1-f9e2-49e7-9f2b-d97bbca13268
outputTable: String = /Workspace/Users/eric.marnadi@databricks.com/streaming_query/output_table_8a30fda6-cb3a-4650-9c28-8d76352b18aa
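Because both paths include freshly generated UUIDs, repeated runs leave old checkpoints and output directories behind. A small cleanup step (a sketch using Databricks file utilities; adjust the paths to your environment) can remove them once you are done:

// Optional cleanup after the demo: recursively delete the generated paths
dbutils.fs.rm(checkpointLocation, true)
dbutils.fs.rm(outputTable, true)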

    Define our stateful transformation and start our query

    import spark.implicits._
    
    val sessionStream = inputStream
      .groupByKey(_._1)  // Group by user_id
      .transformWithState(
        new SessionTrackingProcessor(),
        TimeMode.EventTime(),
        OutputMode.Append())
import spark.implicits._
sessionStream: org.apache.spark.sql.Dataset[(String, String, Long, Int, Int, Int)] = [_1: string, _2: string ... 4 more fields]
    val query = sessionStream.writeStream
      .format("delta")
      .option("checkpointLocation", checkpointLocation)
      .outputMode("append")
      .start(outputTable)
    
    query.processAllAvailable()
    query.stop()
    query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.execution.streaming.StreamingQueryWrapper@1d9aefb1
    spark.read.format("delta").load(outputTable).show()
+-----+------+---+---+---+---+
|   _1|    _2| _3| _4| _5| _6|
+-----+------+---+---+---+---+
|user1|ACTIVE|  1|  0|  0|  1|
|user1|ACTIVE|  6|  1|  0|  0|
|user1|ACTIVE| 11|  0|  1|  0|
|user1|ACTIVE| 16|  1|  0|  0|
|user1|ACTIVE| 21|  1|  0|  0|
|user1|ACTIVE| 26|  0|  1|  0|
|user1|ACTIVE| 31|  0|  1|  0|
|user1|ACTIVE| 36|  1|  0|  0|
|user1|ACTIVE| 41|  0|  1|  0|
|user1|ACTIVE| 46|  0|  0|  1|
|user1|ACTIVE| 51|  0|  1|  0|
|user1|ACTIVE| 56|  0|  1|  0|
|user1|ACTIVE| 61|  0|  1|  0|
|user1|ACTIVE| 66|  1|  0|  0|
|user1|ACTIVE| 71|  0|  1|  0|
|user1|ACTIVE| 76|  0|  0|  1|
|user1|ACTIVE| 81|  0|  1|  0|
|user1|ACTIVE| 86|  0|  1|  0|
|user1|ACTIVE| 91|  0|  1|  0|
|user1|ACTIVE| 96|  0|  1|  0|
+-----+------+---+---+---+---+
only showing top 20 rows
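The output columns come out as tuple fields (_1 through _6). For readability, you might rename them when inspecting the results; a sketch, with the descriptive names chosen here:

// A sketch: re-read the output with descriptive column names
spark.read.format("delta").load(outputTable)
  .toDF("user_id", "session_status", "session_value",
        "login_count", "page_view_count", "purchase_count")
  .show()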