DataSourceArrowWriter
A base class for data source writers that process data using PyArrow's RecordBatch.
Unlike DataSourceWriter, which works with an iterator of Spark Row objects, this class is optimized for the Arrow format when writing data. It can offer better performance when interfacing with systems or libraries that natively support Arrow. Implement this class and return an instance from DataSource.writer() to make a data source writable using Arrow.
Syntax
```python
from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        ...
```
Methods
| Method | Description |
|---|---|
| `write(iterator)` | Writes an iterator of PyArrow `RecordBatch`es into the data source. |
| `commit(messages)` | Commits the writing job using a list of commit messages collected from all executors. Invoked on the driver when all tasks run successfully. Inherited from `DataSourceWriter`. |
| `abort(messages)` | Aborts the writing job using a list of commit messages collected from all executors. Invoked on the driver when one or more tasks failed. Inherited from `DataSourceWriter`. |
Notes
- The driver collects commit messages from all executors and passes them to `commit()` if all tasks succeed, or to `abort()` if any task fails.
- If a write task fails, its commit message will be `None` in the list passed to `commit()` or `abort()`.
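This commit protocol can be simulated without Spark. The sketch below uses plain Python stand-ins (the `run_job` driver loop and `CommitMessage` dataclass are hypothetical, not part of the PySpark API) to show how `None` entries appear in the message list when a task fails:

```python
from dataclasses import dataclass


@dataclass
class CommitMessage:
    # Stand-in for a WriterCommitMessage subclass.
    num_rows: int


def run_job(task_results):
    """Hypothetical driver loop: one entry per task (None if it failed),
    then a decision between commit and abort."""
    messages = list(task_results)
    if all(m is not None for m in messages):
        # All tasks succeeded: commit with the full message list.
        total = sum(m.num_rows for m in messages)
        return ("commit", total)
    # At least one task failed: abort, skipping the None entries.
    total = sum(m.num_rows for m in messages if m is not None)
    return ("abort", total)


print(run_job([CommitMessage(10), CommitMessage(5)]))  # ('commit', 15)
print(run_job([CommitMessage(10), None]))              # ('abort', 10)
```

Note how both branches must tolerate `None`, mirroring the contract described in the notes above.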
Examples
Implement an Arrow-based writer that counts rows across all batches:
```python
from dataclasses import dataclass

from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage


@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int


class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        # Runs on each executor; counts rows across all Arrow batches.
        total_rows = 0
        for batch in iterator:
            total_rows += batch.num_rows
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages):
        # Runs on the driver; messages from failed tasks are None.
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")
```