    TWS initialize state store from Delta in Scala

    Create synthetic data using dbldatagen in Python

    %sh
    pip install dbldatagen
    Collecting dbldatagen
      Downloading dbldatagen-0.4.0.post1-py3-none-any.whl.metadata (9.9 kB)
    Downloading dbldatagen-0.4.0.post1-py3-none-any.whl (122 kB)
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 122.8/122.8 kB 3.8 MB/s eta 0:00:00
    Installing collected packages: dbldatagen
    Successfully installed dbldatagen-0.4.0.post1
    %python
    import dbldatagen as dg
    from pyspark.sql.types import *
    from datetime import datetime, timedelta
    
    # Generate synthetic test data
    data_rows = 1000 * 100
    base_time = datetime.now()
    
    df_spec = (dg.DataGenerator(spark, name="ip_activity_data", rows=data_rows)
        # IP addresses with mix of normal and suspicious
        .withColumn("ip_address", StringType(), 
            values=['192.168.1.1', '10.0.0.2', '172.16.0.3', '192.168.1.4', '10.0.0.5'])
        
        # Usernames including common attack targets
        .withColumn("username", StringType(), 
            values=['user1', 'user2', 'admin', 'root', 'administrator', 'system'])
        
        # Action types
        .withColumn("action_type", StringType(), 
            values=['login', 'api_call', 'file_access', 'data_export'])
        
        # Timestamps within last hour
        .withColumn("timestamp", TimestampType(),
            begin=base_time - timedelta(hours=1),
            end=base_time,
            random=True)
        
        # Success flag - weighted towards success (9:1 ratio chosen for illustration)
        .withColumn("success", BooleanType(), 
            values=[True, False], weights=[9, 1])
        
        # Request paths
        .withColumn("request_path", StringType(),
            values=[
                '/api/v1/login',
                '/api/v1/data',
                '/admin/console',
                '/api/v1/export',
                '/api/v1/users'
            ])
        
        # User agents with mix of normal and suspicious
        .withColumn("user_agent", StringType(),
            values=[
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
                'Python-urllib/3.8',
                'curl/7.64.1',
                'Apache-HttpClient/4.5.13',
                'Suspicious-Bot/1.0'
            ])
        
        # Country codes
        .withColumn("country_code", StringType(),
            values=['US', 'GB', 'CN', 'RU', 'FR', 'DE'])
        
        # HTTP Methods
        .withColumn("request_method", StringType(),
            values=['GET', 'POST', 'PUT', 'DELETE'])
    )
    
    # Generate the DataFrame
    df = df_spec.build()
    
    # Write to Delta table
    table_name = dbutils.widgets.get("table_name")
    df.write.format("delta").mode("overwrite").saveAsTable(table_name)
    
    # Display sample data
    display(df.limit(5))

    Set StateStoreProvider and input table name

    // Scala code
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.databricks.sql.streaming.state.RocksDBStateStoreProvider"
    )
    
    val tableName = dbutils.widgets.get("table_name")
    // Use spark.table() instead of read.format("delta").load()
    val df = spark.table(tableName)
    tableName: String = initial_state_input_table
    df: org.apache.spark.sql.DataFrame = [ip_address: string, username: string ... 7 more fields]
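
    Both the Python and Scala cells above read the table name from a table_name widget. If the widget does not exist yet (for example, when running the notebook interactively), it can be created with a default value. A minimal sketch; the default value here is taken from the recorded output above:

    // Create the table_name widget with a default if it has not been defined yet.
    // "initial_state_input_table" matches the value shown in the output above.
    dbutils.widgets.text("table_name", "initial_state_input_table")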

    Define the state case classes that our processor will use

    package org.apache.spark.sql.streaming
    
    // Case classes for input events, output metrics, and per-IP state
    object IPState {
      case class SecurityEvent(
        ipAddress: String,
        username: String,
        actionType: String,
        timestamp: Long,
        success: Boolean,
        requestPath: String,
        userAgent: String,
        countryCode: String,
        requestMethod: String
      )
    
      case class SecurityMetrics(
        ipAddress: String,
        threatLevel: String,
        loginAttempts: Int,
        failedLogins: Int,
        distinctUsernameCount: Int,
        totalRequests: Long,
        lastAction: String,
        adminAttempts: Int,
        suspiciousUserAgents: Int,
        distinctCountryCount: Int
      )
    
      case class IPActivityState(
        lastTimestamp: Long,
        loginAttempts: Int,
        failedLogins: Int,
        distinctUsernameCount: Int,
        lastUsername: String,
        highSpeedAttempts: Int,
        totalRequests: Long,
        lastRequestTimestamp: Long,
        distinctPathCount: Int,
        lastPath: String,
        suspiciousUserAgents: Int,
        distinctCountryCount: Int,
        lastCountry: String,
        adminAttempts: Int
      )
      // Initial state case class
      case class IPInitialState(
        lastTimestamp: Long,
        loginAttempts: Int,
        failedLogins: Int,
        distinctUsernameCount: Int,
        lastUsername: String,
        highSpeedAttempts: Int,
        totalRequests: Long,
        lastRequestTimestamp: Long,
        distinctPathCount: Int,
        lastPath: String,
        suspiciousUserAgents: Int,
        distinctCountryCount: Int,
        lastCountry: String,
        adminAttempts: Int
      )
    }
    
    
    Warning: classes defined within packages cannot be redefined without a cluster restart. Compilation successful.

    Import our state case classes and the required Spark classes

    // Import the RocksDB state store provider
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import java.util.UUID
    import org.apache.spark.sql.streaming.StatefulProcessor
    import org.apache.spark.sql.streaming._
    
    import java.sql.Timestamp
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.IPState._
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import java.util.UUID
    import org.apache.spark.sql.streaming.StatefulProcessor
    import org.apache.spark.sql.streaming._
    import java.sql.Timestamp
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.IPState._

    Define our StatefulProcessor

    class SuspiciousIPTrackingProcessor 
      extends StatefulProcessorWithInitialState[String, SecurityEvent, SecurityMetrics, IPInitialState] {
      
      @transient protected var _ipState: ValueState[IPInitialState] = _
      
      // Configuration thresholds
      val suspiciousTimeWindowMs: Long = 5 * 60 * 1000  // 5 minutes
      val rapidAttemptThresholdMs: Long = 1000  // 1 second between attempts
      val maxFailedLoginThreshold: Int = 5
      val maxDistinctUsersThreshold: Int = 3
      
      // Patterns used to flag suspicious user agents and admin paths
      val suspiciousUserAgentPatterns = Array(
        "bot", "curl", "python", "wget", "script", "http-client"
      ).map(_.toLowerCase)
      
      val adminPaths = Array(
        "/admin", "/console", "/management", "/actuator", "/metrics"
      ).map(_.toLowerCase)
      
      override def init(
          outputMode: OutputMode,
          timeMode: TimeMode): Unit = {
        _ipState = getHandle.getValueState[IPInitialState](
          "ipState", 
          Encoders.product[IPInitialState],
          TTLConfig.NONE)
      }
    
      override def handleInitialState(
          key: String, 
          initialState: IPInitialState,
          timerValues: TimerValues): Unit = {
        _ipState.update(initialState)
      }
      
      override def handleInputRows(
          key: String,
          inputRows: Iterator[SecurityEvent],
          timerValues: TimerValues
      ): Iterator[SecurityMetrics] = {
        
        // Use a var so that each event in the batch sees the state produced by the
        // previous event, rather than the stale value read at the start of the batch
        var currentState = Option(_ipState.get()).getOrElse(
          IPInitialState(
            lastTimestamp = 0L,
            loginAttempts = 0,
            failedLogins = 0,
            distinctUsernameCount = 0,
            lastUsername = "",
            highSpeedAttempts = 0,
            totalRequests = 0L,
            lastRequestTimestamp = 0L,
            distinctPathCount = 0,
            lastPath = "",
            suspiciousUserAgents = 0,
            distinctCountryCount = 0,
            lastCountry = "",
            adminAttempts = 0
          )
        )
        
        val results = inputRows.map { event => 
            // Calculate time since last request
            val timeSinceLastRequest = event.timestamp - currentState.lastRequestTimestamp
            
            // Check for high-speed attempts
            val newHighSpeedAttempts = if (timeSinceLastRequest < rapidAttemptThresholdMs) {
              currentState.highSpeedAttempts + 1
            } else {
              currentState.highSpeedAttempts
            }
            
            // Check for suspicious user agent
            val isSuspiciousAgent = suspiciousUserAgentPatterns.exists(pattern => 
              Option(event.userAgent).exists(_.toLowerCase.contains(pattern)))
            
            // Check for admin path access
            val isAdminPath = adminPaths.exists(pattern => 
              Option(event.requestPath).exists(_.toLowerCase.contains(pattern)))
            
            // Update distinct counts
            val newUsernameCount = if (event.username != currentState.lastUsername) {
              currentState.distinctUsernameCount + 1
            } else {
              currentState.distinctUsernameCount
            }
            
            val newPathCount = if (event.requestPath != currentState.lastPath) {
              currentState.distinctPathCount + 1
            } else {
              currentState.distinctPathCount
            }
            
            val newCountryCount = if (event.countryCode != currentState.lastCountry) {
              currentState.distinctCountryCount + 1
            } else {
              currentState.distinctCountryCount
            }
            
            // Update state
            val newState = currentState.copy(
              lastTimestamp = event.timestamp,
              loginAttempts = if (event.actionType == "login") currentState.loginAttempts + 1 else currentState.loginAttempts,
              failedLogins = if (event.actionType == "login" && !event.success) currentState.failedLogins + 1 else currentState.failedLogins,
              distinctUsernameCount = newUsernameCount,
              lastUsername = event.username,
              highSpeedAttempts = newHighSpeedAttempts,
              totalRequests = currentState.totalRequests + 1,
              lastRequestTimestamp = event.timestamp,
              distinctPathCount = newPathCount,
              lastPath = event.requestPath,
              suspiciousUserAgents = if (isSuspiciousAgent) currentState.suspiciousUserAgents + 1 else currentState.suspiciousUserAgents,
              distinctCountryCount = newCountryCount,
              lastCountry = event.countryCode,
              adminAttempts = if (isAdminPath) currentState.adminAttempts + 1 else currentState.adminAttempts
            )
            
            // Calculate threat level
            val threatLevel = {
              if (newState.failedLogins >= maxFailedLoginThreshold ||
                  newState.distinctUsernameCount >= maxDistinctUsersThreshold ||
                  newState.highSpeedAttempts >= 10 ||
                  newState.adminAttempts >= 3 ||
                  newState.distinctCountryCount >= 3) "HIGH"
              else if (newState.failedLogins >= 3 ||
                       newState.distinctUsernameCount >= 2 ||
                       newState.highSpeedAttempts >= 5 ||
                       newState.suspiciousUserAgents >= 2) "MEDIUM"
              else "LOW"
            }
            
            // Persist the new state and carry it forward to the next event in this batch
            _ipState.update(newState)
            currentState = newState
            
            // Return security metrics
            SecurityMetrics(
              ipAddress = event.ipAddress,
              threatLevel = threatLevel,
              loginAttempts = newState.loginAttempts,
              failedLogins = newState.failedLogins,
              distinctUsernameCount = newState.distinctUsernameCount,
              totalRequests = newState.totalRequests,
              lastAction = event.actionType,
              adminAttempts = newState.adminAttempts,
              suspiciousUserAgents = newState.suspiciousUserAgents,
              distinctCountryCount = newState.distinctCountryCount
            )
        }.toList
        
        results.iterator
      }
    }
    defined class SuspiciousIPTrackingProcessor

    Define our input stream

    import spark.implicits._
    
    val inputStream = spark.readStream
      .format("delta")
      .table(tableName)
      .selectExpr(
        "ip_address as ipAddress",
        "username",
        "action_type as actionType",
        "timestamp",
        "success",
        "request_path as requestPath",
        "user_agent as userAgent",
        "country_code as countryCode",
        "request_method as requestMethod"
      )
      .as[SecurityEvent]
    import spark.implicits._
    inputStream: org.apache.spark.sql.Dataset[org.apache.spark.sql.streaming.IPState.SecurityEvent] = [ipAddress: string, username: string ... 7 more fields]

    Define output table and checkpoint location

    val checkpointLocation = "/tmp/streaming_query/checkpoint_" + UUID.randomUUID().toString
    val outputTable = "/tmp/streaming_query/output_table_" + UUID.randomUUID().toString
    checkpointLocation: String = /Workspace/Users/eric.marnadi@databricks.com/streaming_query/checkpoint_ebd91116-aff2-441f-8bc3-0762342bf1ce
    outputTable: String = /Workspace/Users/eric.marnadi@databricks.com/streaming_query/output_table_e79fc2c6-1c83-473f-98a2-20a7f66de599

    Define our stateful transformation and start our query

    val initialStates = Seq(
      // IP with previous suspicious activity
      ("192.168.1.1", IPInitialState(
        lastTimestamp = System.currentTimeMillis(),
        loginAttempts = 2,
        failedLogins = 1,
        distinctUsernameCount = 2, // Multiple usernames tried
        lastUsername = "admin",
        highSpeedAttempts = 1,
        totalRequests = 5L,
        lastRequestTimestamp = System.currentTimeMillis() - 1000,
        distinctPathCount = 2,
        lastPath = "/admin/console",
        suspiciousUserAgents = 1,
        distinctCountryCount = 2,
        lastCountry = "RU",
        adminAttempts = 1
      )),
      // IP with moderate activity
      ("10.0.0.2", IPInitialState(
        lastTimestamp = System.currentTimeMillis(),
        loginAttempts = 3,
        failedLogins = 0,
        distinctUsernameCount = 1,
        lastUsername = "user1",
        highSpeedAttempts = 0,
        totalRequests = 10L,
        lastRequestTimestamp = System.currentTimeMillis() - 5000,
        distinctPathCount = 3,
        lastPath = "/api/v1/data",
        suspiciousUserAgents = 0,
        distinctCountryCount = 1,
        lastCountry = "US",
        adminAttempts = 0
      )),
      // Clean IP with minimal history
      ("172.16.0.3", IPInitialState(
        lastTimestamp = System.currentTimeMillis(),
        loginAttempts = 1,
        failedLogins = 0,
        distinctUsernameCount = 1,
        lastUsername = "user2",
        highSpeedAttempts = 0,
        totalRequests = 2L,
        lastRequestTimestamp = System.currentTimeMillis() - 10000,
        distinctPathCount = 1,
        lastPath = "/api/v1/login",
        suspiciousUserAgents = 0,
        distinctCountryCount = 1,
        lastCountry = "GB",
        adminAttempts = 0
      ))
    ).toDS()
      .groupByKey(_._1)
      .mapValues(_._2)
    initialStates: org.apache.spark.sql.KeyValueGroupedDataset[String,org.apache.spark.sql.streaming.IPState.IPInitialState] = KeyValueGroupedDataset: [key: [value: string], value: [lastTimestamp: bigint, loginAttempts: int ... 12 more field(s)]]
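
    The initial states above are hardcoded for demonstration. In line with the notebook title, the same KeyValueGroupedDataset could instead be built from a Delta table holding a per-IP state snapshot. A minimal sketch, assuming a hypothetical table named ip_state_snapshot whose columns are ip_address plus fields matching IPInitialState:

    import org.apache.spark.sql.functions.{col, struct}

    // Hypothetical Delta table with one row per IP and columns matching IPInitialState
    val snapshot = spark.table("ip_state_snapshot")

    val initialStatesFromDelta = snapshot
      .select(
        col("ip_address").as("_1"),
        struct(
          col("lastTimestamp"), col("loginAttempts"), col("failedLogins"),
          col("distinctUsernameCount"), col("lastUsername"), col("highSpeedAttempts"),
          col("totalRequests"), col("lastRequestTimestamp"), col("distinctPathCount"),
          col("lastPath"), col("suspiciousUserAgents"), col("distinctCountryCount"),
          col("lastCountry"), col("adminAttempts")
        ).as("_2")
      )
      .as[(String, IPInitialState)]
      .groupByKey(_._1)
      .mapValues(_._2)
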
    val suspiciousIPStream = inputStream
      .groupByKey(_.ipAddress)
      .transformWithState(
        new SuspiciousIPTrackingProcessor(),
        TimeMode.EventTime(),
        OutputMode.Append(),
        initialStates
      )
    suspiciousIPStream: org.apache.spark.sql.Dataset[org.apache.spark.sql.streaming.IPState.SecurityMetrics] = [ipAddress: string, threatLevel: string ... 8 more fields]
    val query = suspiciousIPStream.writeStream
      .format("delta")
      .option("checkpointLocation", checkpointLocation)
      .outputMode("append")
      .start(outputTable)
    
    query.processAllAvailable()
    query.stop()
    query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.execution.streaming.StreamingQueryWrapper@45506b33
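
    Here the query is drained once with processAllAvailable() and then stopped, which is convenient for a demo. For a continuously running pipeline, the same stream would typically be started with a trigger and left running. A minimal sketch; the one-minute interval is an arbitrary example:

    import org.apache.spark.sql.streaming.Trigger

    // Sketch: run the same stateful stream continuously with one-minute micro-batches
    val continuousQuery = suspiciousIPStream.writeStream
      .format("delta")
      .option("checkpointLocation", checkpointLocation)
      .outputMode("append")
      .trigger(Trigger.ProcessingTime("1 minute"))
      .start(outputTable)
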
    import org.apache.spark.sql.functions._
    
    spark.read.format("delta").load(outputTable)
      .orderBy(
        // Custom threat level ordering (3 for HIGH, 2 for MEDIUM, 1 for LOW)
        expr("""CASE 
               |  WHEN threatLevel = 'HIGH' THEN 3 
               |  WHEN threatLevel = 'MEDIUM' THEN 2 
               |  WHEN threatLevel = 'LOW' THEN 1 
               |END""".stripMargin).desc,
        // Then by specific risk factors
        col("failedLogins").desc,
        col("distinctUsernameCount").desc,
        col("adminAttempts").desc,
        col("suspiciousUserAgents").desc,
        col("distinctCountryCount").desc
      )
      .show(false)
    +-----------+-----------+-------------+------------+---------------------+-------------+-----------+-------------+--------------------+--------------------+
    |ipAddress  |threatLevel|loginAttempts|failedLogins|distinctUsernameCount|totalRequests|lastAction |adminAttempts|suspiciousUserAgents|distinctCountryCount|
    +-----------+-----------+-------------+------------+---------------------+-------------+-----------+-------------+--------------------+--------------------+
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |file_access|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |3            |1           |3                    |6            |login      |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |3            |1           |3                    |6            |login      |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |file_access|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |api_call   |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |file_access|1            |1                   |3                   |
    |192.168.1.1|HIGH       |3            |1           |3                    |6            |login      |1            |1                   |3                   |
    |192.168.1.1|HIGH       |2            |1           |3                    |6            |data_export|1            |1                   |3                   |
    |192.168.1.1|HIGH       |3            |1           |3                    |6            |login      |1            |1                   |3                   |
    +-----------+-----------+-------------+------------+---------------------+-------------+-----------+-------------+--------------------+--------------------+
    only showing top 20 rows

    import org.apache.spark.sql.functions._
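
    Once the results are no longer needed, the temporary checkpoint and output locations created above can be removed. A minimal cleanup sketch using dbutils:

    // Clean up the temporary checkpoint and output locations
    dbutils.fs.rm(checkpointLocation, true)
    dbutils.fs.rm(outputTable, true)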