DataSourceArrowWriter

A base class for data source writers that process data using PyArrow's RecordBatch.

Unlike DataSourceWriter, which works with an iterator of Spark Row objects, this class is optimized for the Arrow format when writing data. It can offer better performance when interfacing with systems or libraries that natively support Arrow. Implement this class and return an instance from DataSource.writer() to make a data source writable using Arrow.

Syntax

Python
from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        ...

Methods

| Method | Description |
| --- | --- |
| write(iterator) | Writes an iterator of PyArrow RecordBatch objects to the sink. Called once on each executor. Returns a WriterCommitMessage, or None if there is no commit message. This method is abstract and must be implemented. |
| commit(messages) | Commits the writing job using a list of commit messages collected from all executors. Invoked on the driver when all tasks run successfully. Inherited from DataSourceWriter. |
| abort(messages) | Aborts the writing job using a list of commit messages collected from all executors. Invoked on the driver when one or more tasks failed. Inherited from DataSourceWriter. |

Notes

  • The driver collects commit messages from all executors and passes them to commit() if all tasks succeed, or to abort() if any task fails.
  • If a write task fails, its commit message will be None in the list passed to commit() or abort().
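
The driver-side flow described in these notes can be illustrated with a small pure-Python simulation. This is only a sketch of the protocol, not Spark's implementation: `run_write_job`, `CountingWriter`, and `CountMessage` are hypothetical names, plain lists stand in for Arrow batches, and the tasks run sequentially in one process rather than on executors.

```python
from dataclasses import dataclass


@dataclass
class CountMessage:
    """Hypothetical stand-in for a WriterCommitMessage subclass."""
    num_rows: int


class CountingWriter:
    """Hypothetical writer exposing the same three methods as a data source writer."""

    def __init__(self):
        self.outcome = None  # records which driver-side callback ran

    def write(self, iterator):
        total = 0
        for batch in iterator:
            if batch is None:
                # Simulate a task failure on a bad batch.
                raise ValueError("corrupt batch")
            total += len(batch)
        return CountMessage(num_rows=total)

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        self.outcome = ("committed", total)

    def abort(self, messages):
        failed = sum(1 for m in messages if m is None)
        self.outcome = ("aborted", failed)


def run_write_job(writer, partitions):
    """Simulate the driver: one write() per partition, then commit or abort."""
    messages = []
    any_failed = False
    for partition in partitions:
        try:
            messages.append(writer.write(iter(partition)))
        except Exception:
            # A failed task contributes None instead of a commit message.
            messages.append(None)
            any_failed = True
    if any_failed:
        writer.abort(messages)
    else:
        writer.commit(messages)
    return writer.outcome


# Two partitions, all tasks succeed: commit() receives both messages.
ok = run_write_job(CountingWriter(), [[[1, 2], [3]], [[4, 5, 6]]])
# Second partition fails: abort() receives [CountMessage(2), None].
bad = run_write_job(CountingWriter(), [[[1, 2]], [None]])
print(ok)   # ('committed', 6)
print(bad)  # ('aborted', 1)
```

Filtering out None entries, as commit() does here, keeps the callbacks safe whether an entry is missing because a task failed or because write() returned no commit message.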

Examples

Implement an Arrow-based writer that counts rows across all batches:

Python
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")