    TWS Python SCD Type 2

    This notebook demonstrates Slowly Changing Dimension (SCD) Type 2 tracking with transformWithStateInPandas (TWS): it generates synthetic session data with dbldatagen, keeps the latest versioned record per user in RocksDB-backed state, and streams the full version history into a Delta table.

    %sh
    pip install dbldatagen
    Collecting dbldatagen
    Successfully installed dbldatagen-0.4.0.post1
    import dbldatagen as dg
    import time
    import datetime
    from pyspark.sql.types import IntegerType, FloatType, StringType, TimestampType
    
    # Set table name
    table_name = f"synthetic_data_{int(time.time())}"
    print(f"table_name: {table_name}")
    
    # Generate session data with user_id, action_type, and timestamp
    data_rows = 1000 * 100
    now = datetime.datetime.now()
    one_hour_ago = now - datetime.timedelta(hours=1)
    
    df_spec = (dg.DataGenerator(spark, name="session_data", rows=data_rows)
               .withColumn("user_id", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'])
               .withColumn("time", TimestampType(), data_range=(one_hour_ago, now), random=True)
               .withColumn("location", StringType(), values=['a', 'b', 'c', 'd', 'e', 'f', 'g']))
    
                                
    df = df_spec.build()
    
    # Write to Delta table
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)
    # Stateful processing with transformWithStateInPandas requires the RocksDB state store provider
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider"
    )
    from pyspark.sql.functions import col
    print(table_name)
    df = spark.table(table_name).withColumn("time", col("time").cast("long"))
    display(df)
    synthetic_data_1739556727
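
    Optionally (not part of the original flow), you can append a second batch whose locations are disjoint from the first, so that when the stream below picks it up in later micro-batches every user changes location and receives a new SCD version. A minimal sketch reusing the generator pattern above; the spec name session_data_v2 and the locations x/y/z are arbitrary choices:

    # Optional sketch: append a later batch with a disjoint location set to force version changes
    df_spec_v2 = (dg.DataGenerator(spark, name="session_data_v2", rows=data_rows)
                  .withColumn("user_id", StringType(), values=['user1', 'user2', 'user3', 'user4', 'user5'])
                  .withColumn("time", TimestampType(), data_range=(now, now + datetime.timedelta(hours=1)), random=True)
                  .withColumn("location", StringType(), values=['x', 'y', 'z']))

    df_spec_v2.build().write.format("delta").mode("append").saveAsTable(table_name)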

    Define output schema

    from pyspark.sql.types import StructType, StructField, StringType, LongType, TimestampType
    
    output_schema = StructType([
        StructField("user", StringType(), True),
        StructField("version", LongType(), True),
        StructField("start_time", TimestampType(), True),
        StructField("end_time", TimestampType(), True),
        StructField("location", StringType(), True)
    ])
    
    

    Define stateful processor

    import pandas as pd
    from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
    from pyspark.sql.types import StructType, StructField, LongType, StringType, TimestampType
    from typing import Iterator
    from datetime import datetime
    
    # Use RocksDB as the state store provider
    spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    
    class SCDType2StatefulProcessor(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            """Initialize stateful storage for tracking SCD Type 2 changes."""
            state_schema = StructType([
                StructField("user", StringType(), True),
                StructField("version", LongType(), True),
                StructField("start_time", TimestampType(), True),
                StructField("end_time", TimestampType(), True),
                StructField("location", StringType(), True)
            ])
            self.latest_version = handle.getValueState("latestVersion", state_schema)
    
        def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
            """Handles incoming records and maintains SCD Type 2 history."""
            # Within this micro-batch, keep only the most recent row (by event time) for this key
            current_time = datetime.utcnow()
            max_row = None
            max_time = float('-inf')
    
            for pdf in rows:
                for _, pd_row in pdf.iterrows():
                    time_value = pd_row["time"]
                    if time_value > max_time:
                        max_time = time_value
                        max_row = tuple(pd_row)
    
            if max_row is None:
                yield pd.DataFrame()
                return
    
            user_id, timestamp, location = max_row[0], max_row[1], max_row[2]
            exists = self.latest_version.exists()
    
            if exists:
                latest_entry = self.latest_version.get()
                latest_location = latest_entry[4]  # Get stored location
                latest_version = latest_entry[1]  # Get stored version number
    
                if location != latest_location:  # If location has changed, create a new version
                    # Close previous record by setting end_time
                    closed_entry = {
                        "user": latest_entry[0],
                        "version": latest_entry[1],
                        "start_time": latest_entry[2],
                        "end_time": current_time,  # Marking old record as expired
                        "location": latest_entry[4]
                    }
    
                    # Create a new version entry
                    new_entry = {
                        "user": user_id,
                        "version": latest_version + 1,
                        "start_time": current_time,
                        "end_time": None,
                        "location": location
                    }
    
                    # Update state with the new version
                    self.latest_version.update(tuple(new_entry.values()))
    
                    yield pd.DataFrame([closed_entry, new_entry])
                else:
                    yield pd.DataFrame()  # No change, so no new output
            else:
                # First entry for this user, initialize state
                new_entry = {
                    "user": user_id,
                    "version": 1,
                    "start_time": current_time,
                    "end_time": None,
                    "location": location
                }
                self.latest_version.update(tuple(new_entry.values()))
                yield pd.DataFrame([new_entry])
    
        def close(self) -> None:
            pass
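
    For intuition, the snippet below (illustrative values only, not part of the pipeline) shows the shape of what handleInputRows emits for a key whose location changes between micro-batches: one row that closes the previous version by filling end_time, followed by an open row for the new version with end_time left as None.

    import pandas as pd
    from datetime import datetime

    # Illustrative only: 'user1' moves from location 'a' to 'c'
    # Row 1 closes version 1 (end_time set); row 2 opens version 2 (end_time still None)
    example_emit = pd.DataFrame([
        {"user": "user1", "version": 1,
         "start_time": datetime(2025, 1, 1, 10, 0), "end_time": datetime(2025, 1, 1, 11, 0),
         "location": "a"},
        {"user": "user1", "version": 2,
         "start_time": datetime(2025, 1, 1, 11, 0), "end_time": None,
         "location": "c"},
    ])
    print(example_emit)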
    
    
    import uuid
    
    base_path = f"/Workspace/Users/bo.gao@databricks.com/tws/{uuid.uuid4()}"
    checkpoint_dir = base_path + "/checkpoint"
    output_path = base_path + "/output"
    from pyspark.sql.functions import col
    
    q = spark \
        .readStream \
        .format("delta") \
        .option("maxFilesPerTrigger", "1") \
        .table(table_name) \
        .withColumn("time", col("time").cast("long")) \
        .groupBy("user_id") \
        .transformWithStateInPandas( \
            statefulProcessor=SCDType2StatefulProcessor(), \
            outputStructType=output_schema, \
            outputMode="Append", \
            timeMode="ProcessingTime", \
        ) \
        .writeStream \
        .format("delta") \
        .option("path", output_path) \
        .option("checkpointLocation", checkpoint_dir) \
        .outputMode("append") \
        .trigger(availableNow=True) \
        .start()
    # Block until the availableNow trigger has processed all input, then stop the query
    q.awaitTermination()
    q.stop()
    output_df = spark.read.format("delta").load(output_path)
    display(output_df)
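
    To read the result as an SCD Type 2 dimension: the open (current) record for each user is the row whose end_time is null, and a user's history is ordered by version. A minimal sketch over the output table written above:

    from pyspark.sql.functions import col

    # Current (open) record per user: end_time has not been filled in yet
    display(output_df.filter(col("end_time").isNull()))

    # Full version history for a single user, ordered by version
    display(output_df.filter(col("user") == "user1").orderBy("version"))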

    (Optional) Drop the test table and delete the output/checkpoint paths

    spark.sql(f"DROP TABLE IF EXISTS {table_name}")
    DataFrame[]
    dbutils.fs.rm(base_path, True)
    True