    TWS Scala SCD Type 2

    Create synthetic data using dbldatagen in Python

    %sh
    pip install dbldatagen
Collecting dbldatagen
  Using cached dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
  Using cached dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
Installing collected packages: dbldatagen
Successfully installed dbldatagen-0.4.0.post1
    %python
    import dbldatagen as dg
    import time
    from pyspark.sql.types import IntegerType, FloatType, StringType, TimestampType
    from datetime import datetime, timedelta
    
    # Set table name
    table_name = f"synthetic_data_{int(time.time())}"
    dbutils.widgets.text("table_name", table_name)
    print(table_name)
    
    # Generate session data with user_id, action_type, and timestamp
    data_rows = 1000 * 100
    now = datetime.now()
    one_hour_ago = now - timedelta(hours=1)
    
    df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
               .withColumn("user", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'])
               .withColumn("time", TimestampType(), data_range=(one_hour_ago, now), random=True)
               .withColumn("location", StringType(), values=['a', 'b', 'c', 'd', 'e', 'f', 'g']))
    
                                
    df = df_spec.build()
    
    # Write to Delta table
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)

    Set StateStoreProvider and input table name

    // Scala code
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    
    val tableName = dbutils.widgets.get("table_name")
    // Use spark.table() instead of read.format("delta").load()
    val df = spark.table(tableName)
    display(df)
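As an optional sanity check, the provider setting can be read back from the session config before moving on:

// Optional: confirm the RocksDB state store provider is set on this session
println(spark.conf.get("spark.sql.streaming.stateStore.providerClass"))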

    Define stateful structs that our processor will use

    package org.apache.spark.sql.streaming
    
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming._
    import org.apache.spark.sql.types.TimestampType
    import java.sql.Timestamp
    
    
    object MS {
        case class UserLocationSCD2(
            user: String,
            version: Long,
            start_time: Timestamp,
            end_time: Option[Timestamp],
            location: String
        )
    
        case class UserLocation(
            user: String,
            time: Long,
            location: String
        )
    
    }
    Warning: classes defined within packages cannot be redefined without a cluster restart. Compilation successful.
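To make the SCD Type 2 shape concrete, here is a small illustrative sketch (hypothetical values, not output from the pipeline) of what a history for a single user looks like with these case classes: each location change closes the previous record by filling end_time and opens a new record with the version incremented.

import java.sql.Timestamp
import org.apache.spark.sql.streaming.MS._

// Hypothetical history for one user: version 1 was closed when the user
// moved from location "a" to "b"; version 2 is the current (open) record.
val t1 = Timestamp.valueOf("2025-01-01 10:00:00")
val t2 = Timestamp.valueOf("2025-01-01 11:00:00")

val history = Seq(
  UserLocationSCD2(user = "user1", version = 1, start_time = t1, end_time = Some(t2), location = "a"),
  UserLocationSCD2(user = "user1", version = 2, start_time = t2, end_time = None, location = "b")
)

// The open record (end_time == None) is the user's current location.
val current = history.filter(_.end_time.isEmpty)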

Import our structs and the necessary streaming classes

    // Import the RocksDB state store provider
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import java.util.UUID
    import org.apache.spark.sql.streaming.StatefulProcessor
    import org.apache.spark.sql.streaming._
    
    import java.sql.Timestamp
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.MS._

    Define our StatefulProcessor

    class SCDType2StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocationSCD2] {
      @transient private var _latestVersion: ValueState[UserLocationSCD2] = _
    
      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
        _latestVersion = getHandle.getValueState[UserLocationSCD2](
          "scd2State",
          Encoders.product[UserLocationSCD2],
          TTLConfig.NONE
        )
      }
    
      override def handleInputRows(
          key: String,
          inputRows: Iterator[UserLocation],
          timerValues: TimerValues): Iterator[UserLocationSCD2] = {
        
        // Get the latest update for this user
        val latestIncoming = inputRows.maxBy(_.time)
        val currentTime = new Timestamp(System.currentTimeMillis())
    
        if (!_latestVersion.exists()) {
          // First record for this user, insert as version 1
          val newRecord = UserLocationSCD2(
            user = latestIncoming.user,
            version = 1,
            start_time = currentTime,
            end_time = None,
            location = latestIncoming.location
          )
          _latestVersion.update(newRecord)
          Iterator.single(newRecord)
        } else {
          val lastRecord = _latestVersion.get()
          
          if (latestIncoming.location != lastRecord.location) {
            // Close previous record and create a new version
            val closedRecord = lastRecord.copy(end_time = Some(currentTime))
            val newRecord = UserLocationSCD2(
              user = latestIncoming.user,
              version = lastRecord.version + 1,
              start_time = currentTime,
              end_time = None,
              location = latestIncoming.location
            )
            _latestVersion.update(newRecord)
            Iterator(closedRecord, newRecord)
          } else {
            // No change, return empty iterator
            Iterator.empty
          }
        }
      }
    }
    defined class SCDType2StatefulProcessor
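As a quick mental model of handleInputRows (a sketch with made-up rows, not an actual run of the processor): within a microbatch only the latest row per key drives the decision, and rows are emitted only when the location actually changes.

import org.apache.spark.sql.streaming.MS._

// Hypothetical microbatch for key "user1": two rows, the later one wins.
val batch = Iterator(
  UserLocation(user = "user1", time = 1L, location = "a"),
  UserLocation(user = "user1", time = 2L, location = "b")
)
val latestIncoming = batch.maxBy(_.time) // location "b" drives the SCD2 decision

// If the stored state already has location "b": no output (Iterator.empty).
// If the stored state has location "a": two rows are emitted, the closed
// previous version (end_time filled) and the new open version (version + 1).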

    Define our input stream

    val inputStream = spark.readStream
      .format("delta")
      .option("maxFilesPerTrigger", "1")
      .table(tableName)  // Use the table name we created
      .as[UserLocation]
    inputStream: org.apache.spark.sql.Dataset[org.apache.spark.sql.streaming.MS.UserLocation] = [user: string, time: timestamp ... 1 more field]

    Define output table and checkpoint location

    val baseLocation = "/Workspace/Users/bo.gao@databricks.com/tws/" + UUID.randomUUID().toString
    val checkpointLocation = baseLocation + "/checkpoint"
    val outputTable = baseLocation + "/output"
baseLocation: String = /Workspace/Users/bo.gao@databricks.com/tws/fd8d30f8-d135-4349-927d-ea6b4b56f843
checkpointLocation: String = /Workspace/Users/bo.gao@databricks.com/tws/fd8d30f8-d135-4349-927d-ea6b4b56f843/checkpoint
outputTable: String = /Workspace/Users/bo.gao@databricks.com/tws/fd8d30f8-d135-4349-927d-ea6b4b56f843/output

    Define our stateful transformation and start our query

    import spark.implicits._
    
    val result = inputStream
      .groupByKey(x => x.user)
      .transformWithState[UserLocationSCD2](new SCDType2StatefulProcessor(),
         TimeMode.ProcessingTime(),
         OutputMode.Append())
    
    val query = result.writeStream
      .format("delta")
      .option("checkpointLocation", checkpointLocation)
      .option("path", outputTable)
      .outputMode("append")
      .queryName("stateful_transform")
      .start()
    Cancelled
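The query above was cancelled manually in this run. One way to watch and stop it programmatically, using the standard StreamingQuery APIs (a minimal sketch):

// Inspect the status and recent progress of the running query, then stop it.
println(query.status)
query.recentProgress.foreach(p => println(p.prettyJson))
query.stop()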
Clean up the checkpoint and output locations (run this after inspecting the output below)

dbutils.fs.rm(baseLocation, true)
res17: Boolean = true
Read back the SCD2 output table

val outputDf = spark.read.format("delta").load(outputTable)
display(outputDf)
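Before running the cleanup cell above, the SCD2 output can also be queried directly, for example for each user's current record (the open row with no end_time) and the full change history; a minimal sketch:

import org.apache.spark.sql.functions._

// Current location per user: the open SCD2 record has a null end_time.
display(outputDf.filter(col("end_time").isNull).orderBy("user"))

// Full change history per user, ordered by version.
display(outputDf.orderBy(col("user"), col("version")))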