    TWS Scala Top K

    Create synthetic data using dbldatagen in Python

    %sh
    pip install dbldatagen
    Collecting dbldatagen
      Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
    Installing collected packages: dbldatagen
    Successfully installed dbldatagen-0.4.0.post1
    %python
    import dbldatagen as dg
    import time
    from pyspark.sql.types import IntegerType, StringType
    
    # Set table name
    table_name = f"synthetic_data_{int(time.time())}"
    dbutils.widgets.text("table_name", table_name)
    print(table_name)
    
    # Generate synthetic data: a user column and an integer value column
    data_rows = 100
    
    df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
               .withColumn("user", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'], random=True)
               .withColumn("value", IntegerType(), minValue=0, maxValue=200)
    )
    
    df = df_spec.build()
    
    # Write to Delta table
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)

    Set StateStoreProvider and input table name

    // Scala code
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    
    val tableName = dbutils.widgets.get("table_name")
    // Use spark.table() instead of read.format("delta").load()
    val df = spark.table(tableName)
    display(df)
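
    transformWithState relies on the RocksDB state store provider configured above. As a quick sanity check (a sketch, not part of the original notebook), the setting can be read back before the streaming query starts:

    // Sketch: verify the state store provider setting before defining the streaming query.
    println(spark.conf.get("spark.sql.streaming.stateStore.providerClass"))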

    Define the stateful processor and the state it uses

    Import the required classes

    // Import the RocksDB state store provider
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import java.util.UUID
    import org.apache.spark.sql.streaming.StatefulProcessor
    import org.apache.spark.sql.streaming._
    
    import java.sql.Timestamp
    import org.apache.spark.sql.Encoders

    Define our StatefulProcessor

    import scala.collection.mutable.PriorityQueue
    
    class TopKStatefulProcessor(k: Int)
      extends StatefulProcessor[String, (String, Int), (Int, String, Int)] {
      // _topKState remembers the top-k elements from the previous micro-batch. It may hold
      // fewer than k elements, but it will never hold more than k.
      @transient private var _topKState: ListState[(String, Int)] = _
    
      override def init(
          outputMode: OutputMode,
          timeMode: TimeMode): Unit = {
        _topKState = getHandle.getListState[(String, Int)](
          "topKState", Encoders.tuple(Encoders.STRING, Encoders.scalaInt), TTLConfig.NONE)
      }
      override def handleInputRows(
          key: String,
          inputRows: Iterator[(String, Int)],
          timerValues: TimerValues)
          : Iterator[(Int, String, Int)] = {
        // Keep a min-heap of up to k elements from this micro-batch, ordered by the Int
        // field of (String, Int), so the smallest retained value sits at the head.
        val pq = new PriorityQueue[(String, Int)]()(Ordering.by(-_._2))
    
        inputRows.foreach { case (name, value) =>
          if (pq.size < k) {
            pq.enqueue((name, value))
          } else {
            // If the new element is larger than the minimum, remove the
            // minimum and add the new element.
            if (value > pq.head._2) {
              println(s"Adding to pq: ($name, $value) since $value > ${pq.head._2}")
              pq.dequeue()
              pq.enqueue((name, value))
            } else {
              println(s"Not adding to pq: ($name, $value) since $value <= ${pq.head._2}")
            }
          }
        }
    
        // Merge the top-k elements from the previous micro-batch with the top-k elements
        // from this micro-batch: repeatedly take the larger head of the two sorted
        // sequences until we have k elements or both are exhausted.
        val immutableOldTopKState = _topKState.get().toList
    
        var oldTopKState = immutableOldTopKState
        // Ordered from largest to smallest
        var newTopKState = Array[(String, Int)]()
    
        var heapEntries = pq.toArray.sortBy(-_._2)
    
        var bothEmpty = false
        while (newTopKState.size < k && !bothEmpty) {
          // Grab the first of each side
          val topHeapElement = heapEntries.headOption
          val topStateElement = oldTopKState.headOption
    
          // If both are empty, we're done
          (topHeapElement, topStateElement) match {
            case (None, None) => bothEmpty = true
            case (Some(heapElement), None) =>
              newTopKState = newTopKState :+ heapElement
              heapEntries = heapEntries.tail
            case (None, Some(stateElement)) =>
              newTopKState = newTopKState :+ stateElement
              oldTopKState = oldTopKState.tail
            case (Some(heapElement), Some(stateElement)) =>
              // Prioritize the heap element if it is the same value as the state element
              if (heapElement._2 >= stateElement._2) {
                newTopKState = newTopKState :+ heapElement
                heapEntries = heapEntries.tail
              } else {
                newTopKState = newTopKState :+ stateElement
                oldTopKState = oldTopKState.tail
              }
          }
        }
        _topKState.put(newTopKState)
    
        // Emit only the top-k elements that changed.
        newTopKState.zipWithIndex.flatMap { case (record, rank) =>
          if (immutableOldTopKState.size <= rank || immutableOldTopKState(rank) != record) {
            Seq((rank + 1, record._1, record._2))
          } else {
            Seq.empty
          }
        }.iterator
      }
    }
    defined class TopKStatefulProcessor
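
    To make the merge-and-emit logic concrete, here is a hypothetical walk-through (not part of the notebook run), assuming k = 3; the names and values below are invented for illustration:

    // Hypothetical example with k = 3 (values invented for illustration):
    val previousTopK = List(("a", 90), ("b", 80), ("c", 70))   // state kept from the last micro-batch
    val batchSorted  = List(("d", 85), ("e", 60))              // this batch's heap, sorted descending
    // Merging repeatedly takes the larger head of the two sequences, keeping the three largest:
    val mergedTopK   = List(("a", 90), ("d", 85), ("b", 80))
    // Ranks 2 and 3 differ from previousTopK, so the processor emits
    // (2, "d", 85) and (3, "b", 80); rank 1 is unchanged and is not re-emitted.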

    Define our input stream

    val inputStream = spark.readStream
      .format("delta")
      .option("maxFilesPerTrigger", "1")
      .table(tableName)  // Use the table name we created
      .as[(String, Int)]
    inputStream: org.apache.spark.sql.Dataset[(String, Int)] = [user: string, value: int]

    Define output table and checkpoint location

    val baseLocation = "/Workspace/Users/bo.gao@databricks.com/tws/" + UUID.randomUUID().toString
    val checkpointLocation = baseLocation + "/checkpoint"
    val outputTable = baseLocation + "/output"
    baseLocation: String = /Workspace/Users/bo.gao@databricks.com/tws/3626d05d-7581-4b9b-ac7c-967b0e585142
    checkpointLocation: String = /Workspace/Users/bo.gao@databricks.com/tws/3626d05d-7581-4b9b-ac7c-967b0e585142/checkpoint
    outputTable: String = /Workspace/Users/bo.gao@databricks.com/tws/3626d05d-7581-4b9b-ac7c-967b0e585142/output

    Define our stateful transformation and start our query

    import spark.implicits._
    import org.apache.spark.sql.streaming.Trigger
    
    val result = inputStream
      .groupByKey(x => x._1)
      .transformWithState[(Int, String, Int)](new TopKStatefulProcessor(5),
         TimeMode.None(),
         OutputMode.Update())
    
    val query = result.writeStream
      .format("console")
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpointLocation)
      .outputMode("update")
      .start()
    result: org.apache.spark.sql.Dataset[(Int, String, Int)] = [_1: int, _2: string ... 1 more field]
    query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.execution.streaming.StreamingQueryWrapper@233f7401
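
    Because the query runs with Trigger.AvailableNow, it stops on its own once the backlog is processed. A minimal way to block until it finishes and inspect progress (a sketch, not part of the original notebook output):

    // Sketch: wait for the AvailableNow run to finish, then look at the last micro-batch's progress.
    query.awaitTermination()
    println(query.lastProgress)

    Note that outputTable is declared above but this query writes to the console sink; the updates could instead be persisted, for example with a foreachBatch writer.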
    Clean up the checkpoint and output locations

    dbutils.fs.rm(baseLocation, true)
    res10: Boolean = true