
Example stateful applications

This article contains code examples for custom stateful applications. Databricks recommends using built-in stateful methods for common operations such as aggregations and joins.

The patterns in this article use the transformWithState operator and associated classes available in Public Preview in Databricks Runtime 16.2 and above. See Build a custom stateful application.

note

Python uses the transformWithStateInPandas operator to provide the same functionality. The examples below provide code in Python and Scala.

Requirements

The transformWithState operator and the related APIs and classes have the following requirements:

  • Available in Databricks Runtime 16.2 and above.
  • Compute must use dedicated or no-isolation access mode.
  • You must use the RocksDB state store provider. Databricks recommends enabling RocksDB as part of the compute configuration.
  • transformWithStateInPandas supports standard access mode in Databricks Runtime 16.3 and above.
note

To enable the RocksDB state store provider for the current session, run the following:

Python
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

Slowly changing dimension (SCD) type 1

The following code is an example of implementing SCD type 1 using transformWithState. SCD type 1 only tracks the most recent value for a given field.

note

You can use Streaming tables and APPLY CHANGES INTO to implement SCD type 1 or type 2 using Delta Lake-backed tables. This example implements SCD type 1 in the state store, which provides lower latency for near real-time applications.

Python
# Import necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator

# Set the state store provider to RocksDB
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

# Define the output schema for the streaming query
output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("time", LongType(), True),
    StructField("location", StringType(), True)
])

# Define a custom StatefulProcessor for slowly changing dimension type 1 (SCD1) operations
class SCDType1StatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Define the schema for the state value
        value_state_schema = StructType([
            StructField("user", StringType(), True),
            StructField("time", LongType(), True),
            StructField("location", StringType(), True)
        ])
        # Initialize the state to store the latest location for each user
        self.latest_location = handle.getValueState("latestLocation", value_state_schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # Find the row with the maximum time value
        max_row = None
        max_time = float('-inf')
        for pdf in rows:
            for _, pd_row in pdf.iterrows():
                time_value = pd_row["time"]
                if time_value > max_time:
                    max_time = time_value
                    max_row = tuple(pd_row)

        # Check if state exists and update if necessary
        exists = self.latest_location.exists()
        if not exists or max_row[1] > self.latest_location.get()[1]:
            # Update the state with the new max row
            self.latest_location.update(max_row)
            # Yield the updated row
            yield pd.DataFrame(
                {"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
            )
        else:
            # Yield an empty DataFrame if no update is needed
            yield pd.DataFrame()

    def close(self) -> None:
        # No cleanup needed
        pass

# Apply the stateful transformation to the input DataFrame
(df.groupBy("user")
    .transformWithStateInPandas(
        statefulProcessor=SCDType1StatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream... # Continue with stream writing configuration
)
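
The `.writeStream...` fragment above is left incomplete. The following is a minimal sketch of one way to finish it, using a console sink for quick inspection; the sink, query name, and checkpoint path are illustrative assumptions, not part of the original example.

Python
# Hypothetical sink configuration; the query name and checkpoint path are placeholders
query = (
    df.groupBy("user")
    .transformWithStateInPandas(
        statefulProcessor=SCDType1StatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream
    .queryName("scd1_latest_locations")
    .format("console")  # console sink for a quick local check
    .outputMode("update")
    .option("checkpointLocation", "/tmp/scd1_checkpoint")
    .start()
)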

Slowly changing dimension (SCD) type 2

The following notebooks contain an example of implementing SCD type 2 using transformWithState in Python or Scala.

SCD Type 2 Python

Open notebook in new tab

SCD Type 2 Scala

Open notebook in new tab

Downtime detector

transformWithState implements timers to allow you to take action based on elapsed time, even if no records for a given key are processed in a microbatch.

The following example implements a pattern for a downtime detector. Each time a new value is seen for a given key, it updates the lastSeen state value, clears any existing timers, and resets a timer for the future.

When a timer expires, the application emits the elapsed time since the last observed event for the key. It then sets a new timer to emit an update 10 seconds later.

Python
# Import necessary libraries
import pandas as pd
from datetime import datetime
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, TimestampType
from typing import Iterator

class DownTimeDetectorStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Define schema for the state value (timestamp)
        state_schema = StructType([StructField("value", TimestampType(), True)])
        self.handle = handle
        # Initialize state to store the last seen timestamp for each key
        self.last_seen = handle.getValueState("last_seen", state_schema)

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        # Retrieve the last seen timestamp from state
        latest_from_existing = self.last_seen.get()
        # Calculate downtime as the elapsed time since the last observed event
        downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(
            latest_from_existing[0].timestamp() * 1000
        )
        # Register a new timer for 10 seconds in the future
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
        # Yield a DataFrame with the key and downtime duration
        yield pd.DataFrame(
            {
                "id": key,
                "timeValues": str(downtime_duration),
            }
        )

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Find the row with the maximum timestamp across all input rows
        max_row = max(
            (tuple(r) for pdf in rows for r in pdf.itertuples(index=False)),
            key=lambda row: row[1],
        )

        # Get the latest timestamp from existing state or use epoch start if not exists
        if self.last_seen.exists():
            latest_from_existing = self.last_seen.get()[0]
        else:
            latest_from_existing = datetime.fromtimestamp(0)

        # If new data is more recent than existing state
        if latest_from_existing < max_row[1]:
            # Delete all existing timers
            for timer in self.handle.listTimers():
                self.handle.deleteTimer(timer)
            # Update the last seen timestamp
            self.last_seen.update((max_row[1],))

            # Register a new timer for 5 seconds in the future
            self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)

        # Get current processing time in milliseconds
        timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())

        # Yield a DataFrame with the key and current timestamp
        yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})

    def close(self) -> None:
        # No cleanup needed
        pass
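
Unlike the other examples, this snippet does not show how the processor is applied. The following is a minimal sketch of wiring it up, assuming the input stream `df` is keyed by an `id` column and the output carries `id` and `timeValues` as strings (these names are assumptions); note that timers require a time mode other than `None`.

Python
from pyspark.sql.types import StructType, StructField, StringType

# Output schema assumed for this sketch: the key plus a string time value
downtime_output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("timeValues", StringType(), True)
])

(df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=DownTimeDetectorStatefulProcessor(),
        outputStructType=downtime_output_schema,
        outputMode="Update",
        timeMode="ProcessingTime",  # required so that registered timers fire
    )
    .writeStream... # Continue with stream writing configuration
)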

Migrate existing state information

The following example demonstrates how to implement a stateful application that accepts an initial state. You can add initial state handling to any stateful application, but the initial state can only be set when first initializing the application.

This example uses the statestore reader to load existing state information from a checkpoint path. An example use case for this pattern is migrating from legacy stateful applications to transformWithState.
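
For example, the following is a minimal sketch of reading existing state from a checkpoint with the statestore data source; the checkpoint path is a placeholder, and the commented options show how to target a specific stateful operator or batch when needed.

Python
# Placeholder checkpoint path of the legacy stateful query
existing_state = (
    spark.read.format("statestore")
    .option("path", "/path/to/legacy/checkpoint")
    # .option("operatorId", 0)  # optional: select a specific stateful operator
    # .option("batchId", 10)    # optional: read the state as of a specific batch
    .load()
)

# The returned DataFrame includes key, value, and partition_id columns
display(existing_state)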

Python
# Import necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType, IntegerType
from typing import Iterator

# Set RocksDB as the state store provider for better performance
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

"""
Input schema is as below

input_schema = StructType([
    StructField("id", StringType(), True),
    StructField("value", StringType(), True)
])
"""

# Define the output schema for the streaming query
output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("accumulated", StringType(), True)
])

class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):

    def init(self, handle: StatefulProcessorHandle) -> None:
        # Define schema for the state value (integer)
        state_schema = StructType([StructField("value", IntegerType(), True)])
        # Initialize state to store the accumulated counter for each id
        self.counter_state = handle.getValueState("counter_state", state_schema)
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Check if state exists for the current key
        exists = self.counter_state.exists()
        if exists:
            value_row = self.counter_state.get()
            existing_value = value_row[0]
        else:
            existing_value = 0

        accumulated_value = existing_value

        # Process input rows and accumulate values
        for pdf in rows:
            value = pdf["value"].astype(int).sum()
            accumulated_value += value

        # Update the state with the new accumulated value
        self.counter_state.update((accumulated_value,))

        # Yield a DataFrame with the key and accumulated value
        yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})

    def handleInitialState(self, key, initialState, timerValues) -> None:
        # Initialize the state with the provided initial value
        init_val = initialState.at[0, "initVal"]
        self.counter_state.update((init_val,))

    def close(self) -> None:
        # No cleanup needed
        pass

# Load initial state from a checkpoint directory
initial_state = (
    spark.read.format("statestore")
    .option("path", "$checkpointsDir")
    .load()
)

# Apply the stateful transformation to the input DataFrame
(df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
        initialState=initial_state,
    )
    .writeStream... # Continue with stream writing configuration
)

Migrate Delta table to state store for initialization

The following notebooks contain an example of initializing state store values from a Delta table using transformWithState in Python or Scala.

Initialize state from Delta Python

Open notebook in new tab

Initialize state from Delta Scala

Open notebook in new tab

Session tracking

The following notebooks contain an example of session tracking using transformWithState in Python or Scala.

Session tracking Python

Open notebook in new tab

Session tracking Scala

Open notebook in new tab

Custom stream-stream join using transformWithState

The following code demonstrates a custom stream-stream join across multiple streams using transformWithState. You might use this approach instead of a built-in join operator for the following reasons:

  • You need to use the update output mode, which built-in stream-stream joins do not support. This is especially useful for lower-latency applications.
  • You need to continue to perform joins for late-arriving rows (after watermark expiry).
  • You need to perform many-to-many stream-stream joins.

This example gives the user full control over state expiration logic, allowing for dynamic retention period extension to handle out-of-order events even after the watermark.

Python
# Import necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, TimestampType
from typing import Iterator

# Define output schema for the joined data
output_schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("event_type", StringType(), True),
    StructField("timestamp", TimestampType(), True),
    StructField("profile_name", StringType(), True),
    StructField("email", StringType(), True),
    StructField("preferred_category", StringType(), True)
])

class CustomStreamJoinProcessor(StatefulProcessor):
    # Initialize stateful storage for user profiles, preferences, and event tracking.
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

        # Define schemas for different types of state data
        profile_schema = StructType([
            StructField("name", StringType(), True),
            StructField("email", StringType(), True),
            StructField("updated_at", TimestampType(), True)
        ])
        preferences_schema = StructType([
            StructField("preferred_category", StringType(), True),
            StructField("updated_at", TimestampType(), True)
        ])
        activity_schema = StructType([
            StructField("event_type", StringType(), True),
            StructField("timestamp", TimestampType(), True)
        ])

        # Initialize state storage for user profiles, preferences, and activity
        self.profile_state = handle.getMapState("user_profiles", "string", profile_schema)
        self.preferences_state = handle.getMapState("user_preferences", "string", preferences_schema)
        self.activity_state = handle.getMapState("user_activity", "string", activity_schema)

    # Process incoming events and update state
    def handleInputRows(self, key, rows: Iterator[pd.DataFrame], timer_values) -> Iterator[pd.DataFrame]:
        df = pd.concat(rows, ignore_index=True)

        for _, row in df.iterrows():
            user_id = row["user_id"]

            # Dispatch on which source stream the row came from, based on which fields are populated
            if pd.notna(row.get("event_type")):  # User activity event
                self.activity_state.updateValue(user_id, row.to_dict())
                # Set a timer to process this event after a 10-second delay
                self.handle.registerTimer(timer_values.getCurrentProcessingTimeInMs() + (10 * 1000))

            elif pd.notna(row.get("name")):  # Profile update
                self.profile_state.updateValue(user_id, row.to_dict())

            elif pd.notna(row.get("preferred_category")):  # Preference update
                self.preferences_state.updateValue(user_id, row.to_dict())

        # No immediate output; processing will happen when the timer expires
        return iter([])

    # Perform lookup after delay, handling out-of-order and late-arriving events.
    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:

        # Retrieve stored state for the user
        user_activity = self.activity_state.getValue(key)
        user_profile = self.profile_state.getValue(key)
        user_preferences = self.preferences_state.getValue(key)

        if user_activity:
            # Combine data from different states into a single output row
            output_row = {
                "user_id": key,
                "event_type": user_activity["event_type"],
                "timestamp": user_activity["timestamp"],
                "profile_name": user_profile.get("name") if user_profile else None,
                "email": user_profile.get("email") if user_profile else None,
                "preferred_category": user_preferences.get("preferred_category") if user_preferences else None
            }
            return iter([pd.DataFrame([output_row])])

        return iter([])

    def close(self) -> None:
        # No cleanup needed
        pass

# Apply transformWithState to the input DataFrame
(df.groupBy("user_id")
    .transformWithStateInPandas(
        statefulProcessor=CustomStreamJoinProcessor(),
        outputStructType=output_schema,
        outputMode="Append",
        timeMode="ProcessingTime"
    )
    .writeStream... # Continue with stream writing configuration
)
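
The example above assumes `df` is a single stream keyed by `user_id` that carries rows from every source. The following is a minimal sketch of one way to build such a stream from three hypothetical source tables; the table names, columns, and helper function are assumptions for illustration only.

Python
from pyspark.sql import functions as F

# Hypothetical source streams; names and columns are assumptions for this sketch
activity_df = spark.readStream.table("main.default.user_activity")        # user_id, event_type, timestamp
profile_df = spark.readStream.table("main.default.user_profiles")         # user_id, name, email, updated_at
preferences_df = spark.readStream.table("main.default.user_preferences")  # user_id, preferred_category, updated_at

# Common set of columns and types so the streams can be unioned
column_types = {
    "user_id": "string", "event_type": "string", "timestamp": "timestamp",
    "name": "string", "email": "string", "preferred_category": "string",
    "updated_at": "timestamp",
}

def align(src_df):
    # Add any missing columns as typed nulls, then project to the common column order
    for col, dtype in column_types.items():
        if col not in src_df.columns:
            src_df = src_df.withColumn(col, F.lit(None).cast(dtype))
    return src_df.select(list(column_types))

# Single input stream keyed by user_id, as expected by CustomStreamJoinProcessor
df = align(activity_df).unionByName(align(profile_df)).unionByName(align(preferences_df))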

Top-K computation

The following example uses a ListState with a priority queue to maintain and update the top K elements in a stream for each group key in near real-time.
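
As a rough illustration of the approach, the following is a minimal sketch of a top-K processor that keeps a min-heap of size K in a ListState; the schemas, column names, and value of K are assumptions for this sketch and do not come from the notebooks.

Python
import heapq

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
from typing import Iterator

K = 3  # number of top elements to keep per key

# Output schema assumed for this sketch: the grouping key and its current top-K values
topk_output_schema = StructType([
    StructField("key", StringType(), True),
    StructField("value", DoubleType(), True)
])

class TopKStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # ListState holding the current top-K values for each key
        state_schema = StructType([StructField("value", DoubleType(), True)])
        self.topk_state = handle.getListState("topK", state_schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # Rebuild a min-heap from the stored values, then push the new values
        heap = [row[0] for row in self.topk_state.get()] if self.topk_state.exists() else []
        heapq.heapify(heap)

        for pdf in rows:
            for value in pdf["value"]:
                heapq.heappush(heap, float(value))
                if len(heap) > K:
                    heapq.heappop(heap)  # drop the smallest to keep only K elements

        # Persist the updated top-K list back to state
        self.topk_state.put([(v,) for v in heap])

        # Emit the current top-K values for this key, largest first
        top_values = sorted(heap, reverse=True)
        yield pd.DataFrame({"key": [key[0]] * len(top_values), "value": top_values})

    def close(self) -> None:
        # No cleanup needed
        pass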

Top-K Python

Open notebook in new tab

Top-K Scala

Open notebook in new tab