databricks-logo

    tws-init-python

    (Python)
    Loading...

    TWS initialize state store from Delta in Python

    2
    %sh
    pip install dbldatagen
    Collecting dbldatagen Downloading dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB) Downloading dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 122.8/122.8 kB 4.1 MB/s eta 0:00:00 Installing collected packages: dbldatagen Successfully installed dbldatagen-0.4.0.post1 [notice] A new release of pip is available: 24.0 -> 25.0.1 [notice] To update, run: pip install --upgrade pip
    3
    # Use RocksDB as the streaming state-store backend; required for
    # transformWithStateInPandas (TWS) stateful processing used below.
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    # NOTE(review): `table_name` is only assigned in a later cell of this
    # notebook — this cell fails with NameError unless that cell has already
    # run in this session. Confirm intended execution order.
    print(table_name)
    df = spark.table(table_name)
    4
    import dbldatagen as dg
    from pyspark.sql.types import *
    from datetime import datetime, timedelta
    import pandas as pd
    from pyspark.sql import Row
    from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
    from typing import Iterator
    import uuid
    
    # First, create synthetic data
    def generate_test_data(spark):
        """Build a synthetic IP-activity DataFrame using dbldatagen.

        Generates 100k event rows spread over the last hour, mixing ordinary
        traffic with attack-style usernames, admin paths, and bot user agents.

        Args:
            spark: active SparkSession used by the data generator.

        Returns:
            A Spark DataFrame with columns ip_address, username, action_type,
            timestamp, success, request_path, user_agent, country_code.
        """
        row_count = 1000 * 100
        now = datetime.now()

        spec = dg.DataGenerator(spark, name="ip_activity_data", rows=row_count)

        # IP addresses with mix of normal and suspicious
        spec = spec.withColumn("ip_address", StringType(),
            values=['192.168.1.1', '10.0.0.2', '172.16.0.3', '192.168.1.4', '10.0.0.5'])

        # Usernames including common attack targets
        spec = spec.withColumn("username", StringType(),
            values=['user1', 'user2', 'admin', 'root', 'administrator'])

        # Action types
        spec = spec.withColumn("action_type", StringType(),
            values=['login', 'api_call', 'file_access', 'data_export'])

        # Random timestamps within the last hour
        spec = spec.withColumn("timestamp", TimestampType(),
            begin=now - timedelta(hours=1),
            end=now,
            random=True)

        # Success flag
        spec = spec.withColumn("success", BooleanType(), values=[True, False])

        # Request paths (includes admin console for threat scoring)
        spec = spec.withColumn("request_path", StringType(),
            values=['/api/v1/login', '/api/v1/data', '/admin/console',
                   '/api/v1/export', '/api/v1/users'])

        # User agents (includes scripted/bot clients)
        spec = spec.withColumn("user_agent", StringType(),
            values=['Mozilla/5.0', 'Python-urllib/3.8', 'curl/7.64.1',
                   'Apache-HttpClient/4.5.13', 'Suspicious-Bot/1.0'])

        # Country codes
        spec = spec.withColumn("country_code", StringType(),
            values=['US', 'GB', 'CN', 'RU', 'FR'])

        return spec.build()
    
    # Define our SecurityMetrics processor with initial state
    class SecurityMetricsProcessor(StatefulProcessor):
        """TWS stateful processor keyed by ip_address.

        Keeps a 7-field tuple of per-IP counters in a ValueState, seeds it from
        an initial-state DataFrame, and emits one metrics row (with a derived
        threat level) per key per micro-batch.
        """

        def init(self, handle: StatefulProcessorHandle) -> None:
            """Register the value-state slot that holds the per-IP counters."""
            # Define state schema for IP metrics
            metrics_schema = StructType([
                StructField("login_attempts", IntegerType(), True),
                StructField("failed_logins", IntegerType(), True),
                StructField("distinct_usernames", IntegerType(), True),
                StructField("total_requests", IntegerType(), True),
                StructField("admin_attempts", IntegerType(), True),
                StructField("suspicious_agents", IntegerType(), True),
                StructField("distinct_countries", IntegerType(), True)
            ])

            self.metrics_state = handle.getValueState("metrics_state", metrics_schema)

        def handleInitialState(self, key, initialState, timerValues) -> None:
            """Seed this key's state from its initial-state row.

            Reads index 0 only, so it assumes exactly one initial-state row per
            key — matches an initial-state DataFrame grouped by ip_address.
            """
            # Initialize state from provided initial values
            init_metrics = (
                initialState.at[0, "login_attempts"],
                initialState.at[0, "failed_logins"],
                initialState.at[0, "distinct_usernames"],
                initialState.at[0, "total_requests"],
                initialState.at[0, "admin_attempts"],
                initialState.at[0, "suspicious_agents"],
                initialState.at[0, "distinct_countries"]
            )
            self.metrics_state.update(init_metrics)

        def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
            """Fold this batch's rows for `key` into state and emit one metrics row."""
            # Get existing state or initialize if not exists
            if self.metrics_state.exists():
                current_metrics = self.metrics_state.get()
                login_attempts = current_metrics[0]
                failed_logins = current_metrics[1]
                distinct_usernames = current_metrics[2]
                total_requests = current_metrics[3]
                admin_attempts = current_metrics[4]
                suspicious_agents = current_metrics[5]
                distinct_countries = current_metrics[6]
            else:
                login_attempts = 0
                failed_logins = 0
                distinct_usernames = 0
                total_requests = 0
                admin_attempts = 0
                suspicious_agents = 0
                distinct_countries = 0

            # Process each batch of rows (rows is an iterator of pandas chunks)
            for pdf in rows:
                # Update metrics based on new events
                login_attempts += len(pdf[pdf['action_type'] == 'login'])
                failed_logins += len(pdf[(pdf['action_type'] == 'login') & (~pdf['success'])])
                # NOTE(review): the two distinct counts below use `=`, not `+=` —
                # they are overwritten with the latest chunk's distinct count and
                # do not accumulate across chunks, batches, or the initial state.
                # Confirm this is intended (true cumulative distincts would need
                # set-valued state).
                distinct_usernames = len(pd.unique(pdf['username']))
                total_requests += len(pdf)
                admin_attempts += len(pdf[pdf['request_path'].str.contains('/admin')])
                suspicious_agents += len(pdf[pdf['user_agent'].str.contains('Bot|curl|script', case=False)])
                distinct_countries = len(pd.unique(pdf['country_code']))

            # Update state
            self.metrics_state.update((
                login_attempts,
                failed_logins,
                distinct_usernames,
                total_requests,
                admin_attempts,
                suspicious_agents,
                distinct_countries
            ))

            # Calculate threat level from the updated counters (HIGH checked first)
            threat_level = 'LOW'
            if (failed_logins >= 5 or distinct_usernames >= 3 or 
                admin_attempts >= 3 or distinct_countries >= 3):
                threat_level = 'HIGH'
            elif (failed_logins >= 3 or distinct_usernames >= 2 or 
                  suspicious_agents >= 2):
                threat_level = 'MEDIUM'

            # Return metrics as a single-row DataFrame; key is a tuple, key[0]
            # is the ip_address grouping value
            return iter([pd.DataFrame({
                'ip_address': [key[0]],
                'threat_level': [threat_level],
                'login_attempts': [login_attempts],
                'failed_logins': [failed_logins],
                'distinct_usernames': [distinct_usernames],
                'total_requests': [total_requests],
                'admin_attempts': [admin_attempts],
                'suspicious_agents': [suspicious_agents],
                'distinct_countries': [distinct_countries]
            })])

        def close(self) -> None:
            """No resources to release."""
            pass
    
    # Define output schema: one row per IP — its threat level plus the seven
    # integer counters, in the same order the processor emits them.
    output_schema = StructType(
        [
            StructField("ip_address", StringType(), True),
            StructField("threat_level", StringType(), True),
        ]
        + [
            StructField(metric, IntegerType(), True)
            for metric in (
                "login_attempts",
                "failed_logins",
                "distinct_usernames",
                "total_requests",
                "admin_attempts",
                "suspicious_agents",
                "distinct_countries",
            )
        ]
    )
    
    # Usage example:
    import time
    # Generate and save test data to a uniquely-named Delta table
    table_name = f"security_events_{int(time.time())}"
    df = generate_test_data(spark)
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)

    # Create initial state: pre-seeded counters per IP, in the column order
    # (ip, login_attempts, failed_logins, distinct_usernames, total_requests,
    #  admin_attempts, suspicious_agents, distinct_countries)
    initial_data = [
        ("192.168.1.1", 2, 1, 2, 5, 1, 1, 2),  # Suspicious IP
        ("10.0.0.2", 3, 0, 1, 10, 0, 0, 1),    # Normal IP
        ("172.16.0.3", 1, 0, 1, 2, 0, 0, 1)    # New IP
    ]
    # transformWithStateInPandas takes the initial state as grouped data keyed
    # the same way as the stream (by ip_address)
    initial_state = spark.createDataFrame(
        initial_data,
        ["ip_address", "login_attempts", "failed_logins", "distinct_usernames",
          "total_requests", "admin_attempts", "suspicious_agents", "distinct_countries"]
    ).groupBy("ip_address")

    # Set up streaming query with throwaway checkpoint/output locations
    checkpoint_dir = f"/tmp/security_checkpoint_{uuid.uuid4()}"
    output_path = f"/tmp/security_output_{uuid.uuid4()}"

    # Create and start query
    # NOTE(review): timeMode="EventTime" is set but the stream defines no
    # watermark; confirm this combination is intended for this workload.
    query = (spark.readStream
        .format("delta")
        .table(table_name)
        .groupBy("ip_address")
        .transformWithStateInPandas(
            statefulProcessor=SecurityMetricsProcessor(),
            outputStructType=output_schema,
            outputMode="Update",
            timeMode="EventTime",
            initialState=initial_state
        )
        .writeStream
        .format("delta")
        .outputMode("append")
        .option("checkpointLocation", checkpoint_dir)
        .start(output_path)
    )

    # Drain all available input, then stop and display the results
    query.processAllAvailable()
    query.stop()
    spark.read.format("delta").load(output_path).show()
    +-----------+------------+--------------+-------------+------------------+--------------+--------------+-----------------+------------------+ | ip_address|threat_level|login_attempts|failed_logins|distinct_usernames|total_requests|admin_attempts|suspicious_agents|distinct_countries| +-----------+------------+--------------+-------------+------------------+--------------+--------------+-----------------+------------------+ | 10.0.0.5| MEDIUM| 5000| 0| 1| 20000| 0| 20000| 1| | 172.16.0.3| HIGH| 5001| 0| 1| 20002| 20000| 20000| 1| |192.168.1.1| LOW| 5002| 1| 1| 20005| 1| 1| 1| |192.168.1.4| LOW| 5000| 0| 1| 20000| 0| 0| 1| | 10.0.0.2| LOW| 5003| 0| 1| 20010| 0| 0| 1| +-----------+------------+--------------+-------------+------------------+--------------+--------------+-----------------+------------------+
    ;