PySpark custom data sources
Preview
PySpark custom data sources are in Public Preview in Databricks Runtime 15.2 and above. Streaming support is available in Databricks Runtime 15.3 and above.
A PySpark DataSource is created with the Python (PySpark) DataSource API, which enables reading from custom data sources and writing to custom data sinks in Apache Spark using Python. You can use PySpark custom data sources to define custom connections to data systems and to implement additional functionality, so that you can build reusable data sources.
DataSource class
The PySpark DataSource is a base class that provides methods to create data readers and writers.
Implement the data source subclass
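Depending on your use case, a subclass must implement the following to make a data source readable, writable, or both:
Property or Method | Description |
---|---|
name() | Required. The name of the data source |
schema() | Required. The schema of the data source to be read or written |
reader() | Must return a DataSourceReader to make the data source readable (batch) |
writer() | Must return a DataSourceWriter to make the data sink writable (batch) |
streamReader() or simpleStreamReader() | Must return a DataSourceStreamReader or SimpleDataSourceStreamReader to make the data stream readable (streaming) |
streamWriter() | Must return a DataSourceStreamWriter to make the data stream writable (streaming) |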
Note
The user-defined DataSource, DataSourceReader, DataSourceWriter, DataSourceStreamReader, DataSourceStreamWriter, and their methods must be serializable. In practice, store only simple state on these objects, such as primitive values or dictionaries (including nested dictionaries) of primitive values.
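A quick local check can catch state that will not serialize before you run a query. The sketch below uses the standard pickle module on an object's attributes as an approximation; it does not claim to be the exact serializer Spark uses. The classes PicklableReader and UnpicklableReader and the function state_is_picklable are hypothetical, and they are plain Python classes rather than PySpark subclasses.

```python
import pickle
import threading


class PicklableReader:
    """Keeps only plain data (strings and ints)."""

    def __init__(self, options):
        self.options = dict(options)
        self.num_rows = int(options.get("numRows", 3))


class UnpicklableReader:
    """Holds a lock, which the pickle module cannot serialize."""

    def __init__(self, options):
        self.options = dict(options)
        self.lock = threading.Lock()


def state_is_picklable(obj) -> bool:
    """Return True if every attribute stored on obj can be pickled."""
    try:
        pickle.dumps(vars(obj))
        return True
    except (pickle.PicklingError, TypeError):
        return False
```

Objects such as locks, open files, and client connections are better created inside methods that run on executors, such as read(). This mirrors how the examples below import libraries inside read().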
Register the data source
After implementing the interface, register the data source. You can then load it or otherwise use it by the name that its name() method returns, as shown in the following example:
# Register the data source
spark.dataSource.register(MyDataSourceClass)
# Read from a custom data source
spark.read.format("my_datasource_name").load().show()
Example 1: Create a PySpark DataSource for batch query
To demonstrate PySpark DataSource reader capabilities, create a data source that generates example data using the faker Python package. For more information about faker, see the Faker documentation.
Install the faker package using the following command:
%pip install faker
Step 1: Define the example DataSource
First, define your new PySpark DataSource as a subclass of DataSource with a name, schema, and reader. The reader() method must be defined to read from a data source in a batch query.
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType
class FakeDataSource(DataSource):
    """
    An example data source for batch query using the `faker` library.
    """
    @classmethod
    def name(cls):
        return "fake"
    def schema(self):
        return "name string, date string, zipcode string, state string"
    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)
Step 2: Implement the reader for a batch query
Next, implement the reader logic to generate example data. Use the installed faker library to populate each field in the schema.
class FakeDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options
    def read(self, partition):
        # Library imports must be within the method.
        from faker import Faker
        fake = Faker()
        # Every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)
Step 3: Register and use the example data source
To use the data source, register it. By default, FakeDataSource generates three rows, and the schema includes these string fields: name, date, zipcode, and state. The following example registers, loads, and outputs the example data source with the defaults:
spark.dataSource.register(FakeDataSource)
spark.read.format("fake").load().show()
+-----------------+----------+-------+----------+
| name| date|zipcode| state|
+-----------------+----------+-------+----------+
|Christine Sampson|1979-04-24| 79766| Colorado|
| Shelby Cox|2011-08-05| 24596| Florida|
| Amanda Robinson|2019-01-06| 57395|Washington|
+-----------------+----------+-------+----------+
This reader supports only string fields. However, you can specify a schema with any field names that match faker provider methods, because the reader calls getattr(fake, field.name)() for each field. This lets you generate random data for testing and development. The following example loads the data source with name and company fields:
spark.read.format("fake").schema("name string, company string").load().show()
+---------------------+--------------+
|name |company |
+---------------------+--------------+
|Tanner Brennan |Adams Group |
|Leslie Maxwell |Santiago Group|
|Mrs. Jacqueline Brown|Maynard Inc |
+---------------------+--------------+
To load the data source with a custom number of rows, specify the numRows option. The following example specifies 5 rows:
spark.read.format("fake").option("numRows", 5).load().show()
+--------------+----------+-------+------------+
| name| date|zipcode| state|
+--------------+----------+-------+------------+
| Pam Mitchell|1988-10-20| 23788| Tennessee|
|Melissa Turner|1996-06-14| 30851| Nevada|
| Brian Ramsey|2021-08-21| 55277| Washington|
| Caitlin Reed|1983-06-22| 89813|Pennsylvania|
| Douglas James|2007-01-18| 46226| Alabama|
+--------------+----------+-------+------------+
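Because .option("numRows", 5) reaches the reader as the string "5", the reader has to convert it. The following stdlib-only sketch shows a hypothetical helper, int_option, that a reader could use instead of a bare int() call. It reports a clear error for values that are not integers or are negative.

```python
def int_option(options: dict, key: str, default: int, minimum: int = 0) -> int:
    """Parse an integer option; data source option values arrive as strings."""
    raw = options.get(key)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"Option {key!r} must be an integer, got {raw!r}") from None
    if value < minimum:
        raise ValueError(f"Option {key!r} must be >= {minimum}, got {value}")
    return value


print(int_option({"numRows": "5"}, "numRows", 3))
```

Inside FakeDataSourceReader.read(), the call would read int_option(self.options, "numRows", 3).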
Example 2: Create a PySpark DataSource for streaming read and write
To demonstrate PySpark DataSource stream reader and writer capabilities, create an example data source that generates two rows in every microbatch using the faker Python package. For more information about faker, see the Faker documentation.
Install the faker package using the following command:
%pip install faker
Step 1: Define the example DataSource
First, define your new PySpark DataSource as a subclass of DataSource with a name, schema, and the methods streamReader() and streamWriter().
from pyspark.sql.datasource import DataSource, DataSourceStreamReader, SimpleDataSourceStreamReader, DataSourceStreamWriter
from pyspark.sql.types import StructType
class FakeStreamDataSource(DataSource):
    """
    An example data source for streaming read and write using the `faker` library.
    """
    @classmethod
    def name(cls):
        return "fakestream"
    def schema(self):
        return "name string, state string"
    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)
    # If you don't need partitioning, you can implement the simpleStreamReader method instead of streamReader.
    # def simpleStreamReader(self, schema: StructType):
    #     return SimpleStreamReader()
    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)
Step 2: Implement the stream reader
Next, implement the example streaming data reader that generates two rows in every microbatch. You can implement DataSourceStreamReader. Alternatively, if the data source has low throughput and doesn't require partitioning, you can implement SimpleDataSourceStreamReader instead. Either simpleStreamReader() or streamReader() must be implemented, and simpleStreamReader() is invoked only when streamReader() is not implemented.
DataSourceStreamReader implementation
The FakeStreamReader instance, implemented with the DataSourceStreamReader interface, has an integer offset that increases by 2 in every microbatch.
from typing import Iterator, Tuple
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end
class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0
    def initialOffset(self) -> dict:
        """
        Returns the initial start offset of the reader.
        """
        return {"offset": 0}
    def latestOffset(self) -> dict:
        """
        Returns the current latest offset that the next microbatch will read to.
        """
        self.current += 2
        return {"offset": self.current}
    def partitions(self, start: dict, end: dict):
        """
        Plans the partitioning of the current microbatch defined by start and end offset. It
        needs to return a sequence of :class:`InputPartition` objects.
        """
        return [RangePartition(start["offset"], end["offset"])]
    def commit(self, end: dict):
        """
        This is invoked when the query has finished processing data before end offset. This
        can be used to clean up resources.
        """
        pass
    def read(self, partition) -> Iterator[Tuple]:
        """
        Takes a partition as an input and reads an iterator of tuples from the data source.
        """
        start, end = partition.start, partition.end
        for i in range(start, end):
            # Both schema columns (name, state) are strings.
            yield (str(i), str(i))
SimpleDataSourceStreamReader implementation
The SimpleStreamReader instance behaves like FakeStreamReader and generates two rows in every batch. However, it is implemented with the SimpleDataSourceStreamReader interface, without partitioning.
from typing import Iterator, Tuple
from pyspark.sql.datasource import SimpleDataSourceStreamReader
class SimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self):
        """
        Returns the initial start offset of the reader.
        """
        return {"offset": 0}
    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        """
        Takes start offset as an input, then returns an iterator of tuples and the start offset of the next read.
        """
        start_idx = start["offset"]
        # Emit one value per schema column (name, state).
        it = iter([(str(i), str(i)) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})
    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Takes start and end offset as inputs, then reads an iterator of data deterministically.
        This is called when the query replays batches during restart or after a failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(str(i), str(i)) for i in range(start_idx, end_idx)])
    def commit(self, end):
        """
        This is invoked when the query has finished processing data before end offset. This can be used to clean up resources.
        """
        pass
Step 3: Implement the stream writer
Now implement the streaming writer. This writer records metadata about each microbatch, namely its number of partitions and rows, in files under a local path.
import json
import os
from dataclasses import dataclass
from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage
@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int
class FakeStreamWriter(DataSourceStreamWriter):
    def __init__(self, options):
        self.options = options
        self.path = self.options.get("path")
        assert self.path is not None
    def write(self, iterator):
        """
        Writes the data, then returns the commit message of that partition. Library imports must be within the method.
        """
        from pyspark import TaskContext
        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = 0
        for row in iterator:
            cnt += 1
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)
    def commit(self, messages, batchId) -> None:
        """
        Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it.
        In this FakeStreamWriter, the metadata of the microbatch (number of rows and partitions) is written into a JSON file inside commit().
        """
        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
        with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
            file.write(json.dumps(status) + "\n")
    def abort(self, messages, batchId) -> None:
        """
        Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it.
        In this FakeStreamWriter, a failure message is written into a text file inside abort().
        """
        with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
            file.write(f"failed in batch {batchId}")
Step 4: Register and use the example data source
To use the data source, register it. After it is registered, you can use it in streaming queries as a source or sink by passing its short name or full name to format(). The following example registers the data source, then starts a query that reads from the example data source and outputs to the console:
spark.dataSource.register(FakeStreamDataSource)
query = spark.readStream.format("fakestream").load().writeStream.format("console").start()
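Alternatively, the following example uses the example data source as a sink and specifies an output path:
# A sink other than console needs a checkpoint location; "/checkpoint_path" and "/output_path" are placeholders.
query = spark.readStream.format("fakestream").load().writeStream.format("fakestream").option("checkpointLocation", "/checkpoint_path").start("/output_path")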
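After the sink query has run for a while, for example after you stop it with query.stop(), you can inspect the metadata that FakeStreamWriter.commit() left in the output path on the driver. The stdlib-only sketch below assumes the file layout produced by that writer: one {batchId}.json file of JSON lines per committed batch and one {batchId}.txt file per aborted batch. summarize_batches is a hypothetical helper.

```python
import json
import os


def summarize_batches(path: str) -> dict:
    """Total the metadata files that FakeStreamWriter writes under `path`.

    commit() appends one JSON line such as {"num_partitions": 1, "rows": 2}
    to {batchId}.json; abort() writes {batchId}.txt.
    """
    commits, rows, failed = 0, 0, []
    for file_name in sorted(os.listdir(path)):
        stem, ext = os.path.splitext(file_name)
        full_path = os.path.join(path, file_name)
        if ext == ".json":
            with open(full_path) as f:
                for line in f:
                    if line.strip():
                        commits += 1
                        rows += json.loads(line)["rows"]
        elif ext == ".txt" and stem.isdigit():
            failed.append(int(stem))
    return {"commits": commits, "rows": rows, "failed_batches": sorted(failed)}
```

Run summarize_batches("/output_path") on the driver, using whatever path you passed to start(). Because each microbatch of the example stream carries two rows, rows should be twice commits.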