DataSourceWriter

A base class for data source writers.

Data source writers are responsible for saving data to a data source. Implement this class and return an instance from DataSource.writer() to make a data source writable.

Syntax

Python
from pyspark.sql.datasource import DataSourceWriter

class MyDataSourceWriter(DataSourceWriter):
    def write(self, iterator):
        ...
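To make the writer reachable from Spark, return an instance of it from your data source's writer() method. The sketch below uses stand-in base classes so it runs without a Spark installation; in real code you would subclass pyspark.sql.datasource.DataSource and DataSourceWriter instead, and the source name "my_source" is a hypothetical example:

```python
# Stand-ins for the pyspark.sql.datasource base classes, so this sketch
# runs without PySpark installed. With PySpark available, subclass the
# real DataSource and DataSourceWriter.
class DataSource:
    def __init__(self, options=None):
        self.options = options or {}

class DataSourceWriter: ...

class MyDataSourceWriter(DataSourceWriter):
    def write(self, iterator):
        ...

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        # Hypothetical short name used to refer to this source.
        return "my_source"

    def writer(self, schema, overwrite):
        # Called on the driver when a write is triggered; the returned
        # writer is serialized and shipped to the executors.
        return MyDataSourceWriter()
```

The writer is constructed once on the driver, so any per-write state (such as options) should be set up here before write() runs on the executors.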

Methods

  • write(iterator): Writes data into the data source. Called once on each executor. Accepts an iterator of Row objects and returns a WriterCommitMessage, or None if there is no commit message. This method is abstract and must be implemented.
  • commit(messages): Commits the writing job using the list of commit messages collected from all executors. Invoked on the driver when all tasks run successfully.
  • abort(messages): Aborts the writing job using the list of commit messages collected from all executors. Invoked on the driver when one or more tasks fail.

Notes

  • The driver collects commit messages from all executors and passes them to commit() if all tasks succeed, or to abort() if any task fails.
  • If a write task fails, its commit message will be None in the list passed to commit() or abort().

Examples

Implement a basic writer that saves rows to a file:

Python
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceWriter(DataSourceWriter):
    def __init__(self, options):
        self.path = options.get("path")

    def write(self, iterator):
        rows = list(iterator)
        with open(self.path, "w") as f:
            for row in rows:
                f.write(str(row) + "\n")
        return MyCommitMessage(num_rows=len(rows))

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")
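To see how Spark drives a writer like the one above, the protocol can be simulated without a cluster: one write() call per partition on the "executors", then a single commit() on the "driver" with all collected messages. This is a sketch using stand-in base classes rather than a real Spark run; the partition data and file path are made up for illustration, and the sketch opens the file in append mode since here every partition writes to the same local file:

```python
import os
import tempfile
from dataclasses import dataclass

# Stand-ins for the pyspark.sql.datasource base classes, so this sketch
# runs without PySpark installed.
class DataSourceWriter: ...
class WriterCommitMessage: ...

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceWriter(DataSourceWriter):
    def __init__(self, options):
        self.path = options.get("path")

    def write(self, iterator):
        rows = list(iterator)
        with open(self.path, "a") as f:  # append: one call per partition
            for row in rows:
                f.write(str(row) + "\n")
        return MyCommitMessage(num_rows=len(rows))

    def commit(self, messages):
        # A failed task contributes None to the list, so skip those entries.
        return sum(m.num_rows for m in messages if m is not None)

# Simulate the driver/executor flow with two partitions of toy data.
path = os.path.join(tempfile.mkdtemp(), "out.txt")
writer = MyDataSourceWriter({"path": path})
partitions = [iter([1, 2, 3]), iter([4, 5])]
messages = [writer.write(p) for p in partitions]  # "executor" side
total = writer.commit(messages)                   # "driver" side
```

After this runs, total is 5 and the file contains one line per row, which mirrors how the driver aggregates per-task commit messages after a successful job.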