DataSourceStreamArrowWriter
A base class for data stream writers that process data using PyArrow's RecordBatch.
Unlike DataSourceStreamWriter, which consumes an iterator of Spark Row objects, this class operates directly on PyArrow RecordBatches when writing streaming data. It can offer better performance when interfacing with systems or libraries that natively support Arrow for streaming use cases. Implement this class and return an instance from DataSource.streamWriter() to make a data source writable as a streaming sink using Arrow.
Syntax
from pyspark.sql.datasource import DataSourceStreamArrowWriter
class MyDataSourceStreamArrowWriter(DataSourceStreamArrowWriter):
    def write(self, iterator):
        ...
Methods
| Method | Description |
|---|---|
| `write(iterator)` | Writes an iterator of PyArrow `RecordBatch` objects and returns a `WriterCommitMessage`. |
| `commit(messages, batchId)` | Commits the microbatch using a list of commit messages collected from all executors. Invoked on the driver when all tasks in the microbatch run successfully. Inherited from `DataSourceStreamWriter`. |
| `abort(messages, batchId)` | Aborts the microbatch using a list of commit messages collected from all executors. Invoked on the driver when one or more tasks in the microbatch fail. Inherited from `DataSourceStreamWriter`. |
Notes
- The driver collects commit messages from all executors and passes them to `commit()` if all tasks succeed, or to `abort()` if any task fails.
- If a write task fails, its commit message will be `None` in the list passed to `commit()` or `abort()`.
- `batchId` uniquely identifies each microbatch and increments by 1 with each microbatch processed.
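The driver-side decision described above can be illustrated without Spark. `resolve_batch` below is a hypothetical helper for illustration, not part of the PySpark API:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CommitMessage:
    num_rows: int


def resolve_batch(messages: List[Optional[CommitMessage]]) -> str:
    # A failed task contributes None to the list; any None means
    # the driver calls abort() instead of commit().
    return "abort" if any(m is None for m in messages) else "commit"


all_ok = resolve_batch([CommitMessage(3), CommitMessage(5)])  # all tasks succeeded
one_failed = resolve_batch([CommitMessage(3), None])          # one task failed
```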
Examples
Implement an Arrow-based stream writer that counts rows per microbatch:
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceStreamArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceStreamArrowWriter(DataSourceStreamArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages, batchId):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed batch {batchId} with {total} rows")

    def abort(self, messages, batchId):
        print(f"Batch {batchId} failed, performing cleanup")