DataSourceStreamWriter

A base class for data stream writers.

Data stream writers are responsible for writing data to a streaming sink. Implement this class and return an instance from DataSource.streamWriter() to make a data source writable as a streaming sink. write() is called on executors for each microbatch, and commit() or abort() is called on the driver after all tasks in the microbatch complete.

Syntax

Python
from pyspark.sql.datasource import DataSourceStreamWriter

class MyDataSourceStreamWriter(DataSourceStreamWriter):
    def write(self, iterator):
        ...

Methods

write(iterator)
    Writes data into the streaming sink. Called on executors once per microbatch. Accepts an iterator of Row objects and returns a WriterCommitMessage, or None if there is no commit message. This method is abstract and must be implemented.

commit(messages, batchId)
    Commits the microbatch using a list of commit messages collected from all executors. Invoked on the driver when all tasks in the microbatch complete successfully.

abort(messages, batchId)
    Aborts the microbatch using a list of commit messages collected from all executors. Invoked on the driver when one or more tasks in the microbatch fail.

Notes

  • The driver collects commit messages from all executors and passes them to commit() if all tasks succeed, or to abort() if any task fails.
  • If a write task fails, its commit message will be None in the list passed to commit() or abort().
  • batchId uniquely identifies each microbatch and increments by 1 with each microbatch processed.

Examples

Implement a stream writer that appends rows to a file:

Python
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceStreamWriter(DataSourceStreamWriter):
    def __init__(self, options):
        self.path = options.get("path")

    def write(self, iterator):
        # Runs on executors once per microbatch.
        rows = list(iterator)
        with open(self.path, "a") as f:
            for row in rows:
                f.write(str(row) + "\n")
        return MyCommitMessage(num_rows=len(rows))

    def commit(self, messages, batchId):
        # Runs on the driver after all tasks in the microbatch succeed.
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed batch {batchId} with {total} rows")

    def abort(self, messages, batchId):
        # Runs on the driver if one or more tasks fail.
        print(f"Batch {batchId} failed, performing cleanup")