DataSourceStreamArrowWriter

A base class for data stream writers that process data using PyArrow's RecordBatch.

Unlike DataSourceStreamWriter, which works with an iterator of Spark Row objects, this class is optimized for the Arrow format when writing streaming data. It can offer better performance when interfacing with systems or libraries that natively support Arrow for streaming use cases. Implement this class and return an instance from DataSource.streamWriter() to make a data source writable as a streaming sink using Arrow.

Syntax

Python
from pyspark.sql.datasource import DataSourceStreamArrowWriter

class MyDataSourceStreamArrowWriter(DataSourceStreamArrowWriter):
    def write(self, iterator):
        ...

Methods

write(iterator)
    Writes an iterator of PyArrow RecordBatch objects to the streaming sink. Called on executors once per microbatch. Returns a WriterCommitMessage, or None if there is no commit message. This method is abstract and must be implemented.

commit(messages, batchId)
    Commits the microbatch using a list of commit messages collected from all executors. Invoked on the driver when all tasks in the microbatch run successfully. Inherited from DataSourceStreamWriter.

abort(messages, batchId)
    Aborts the microbatch using a list of commit messages collected from all executors. Invoked on the driver when one or more tasks in the microbatch failed. Inherited from DataSourceStreamWriter.

Notes

  • The driver collects commit messages from all executors and passes them to commit() if all tasks succeed, or to abort() if any task fails.
  • If a write task fails, its commit message will be None in the list passed to commit() or abort().
  • batchId uniquely identifies each microbatch and increments by 1 with each microbatch processed.
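The aggregation described in the notes above can be sketched without Spark at all. This is a plain-Python illustration (the names `CommitMessage` and `total_committed_rows` are hypothetical, not part of the PySpark API) of why driver-side logic must skip None entries:

```python
from dataclasses import dataclass
from typing import List, Optional

# Illustrative stand-in for a WriterCommitMessage subclass.
@dataclass
class CommitMessage:
    num_rows: int

def total_committed_rows(messages: List[Optional[CommitMessage]]) -> int:
    # Failed tasks contribute None to the list the driver receives,
    # so filter them out before aggregating.
    return sum(m.num_rows for m in messages if m is not None)

# Two tasks succeeded, one failed (None):
print(total_committed_rows([CommitMessage(10), None, CommitMessage(5)]))  # 15
```

The same None check appears in the `commit()` implementation in the Examples section below.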

Examples

Implement an Arrow-based stream writer that counts rows per microbatch:

Python
from dataclasses import dataclass
from pyspark.sql.datasource import DataSourceStreamArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceStreamArrowWriter(DataSourceStreamArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages, batchId):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed batch {batchId} with {total} rows")

    def abort(self, messages, batchId):
        print(f"Batch {batchId} failed, performing cleanup")