Pular para o conteúdo principal

DataSourceArrowWriter

Uma classe base para escritores de fontes de dados que processam o uso de dados do PyArrow RecordBatch.

Ao contrário de DataSourceWriter, que funciona com um iterador de objetos Spark Row , esta classe é otimizada para o formato Arrow ao gravar dados. Ele pode oferecer melhor desempenho ao interagir com sistemas ou bibliotecas que oferecem suporte nativo ao Arrow. Implemente esta classe e retorne uma instância de DataSource.writer() para tornar uma fonte de dados gravável usando Arrow.

Sintaxe

Python
from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
def write(self, iterator):
...

Métodos

Método

Descrição

write(iterator)

Escreve um iterador de objetos PyArrow RecordBatch no coletor. Chamado uma vez em cada executor. Retorna WriterCommitMessage ou None se não houver mensagem de commit. Este método é abstrato e precisa ser implementado.

commit(messages)

Confirme a tarefa de escrita usando uma lista de mensagens commit coletadas de todos os executores. Invocado no driver quando todas as tarefas forem executadas com sucesso. Herdado de DataSourceWriter.

abort(messages)

Aborta a tarefa de escrita usando uma lista de mensagens commit coletadas de todos os executores. Invocado no driver quando uma ou mais tarefas falharam. Herdado de DataSourceWriter.

Notas

  • O driver coleta mensagens commit de todos os executores e as passa para commit() se todas as tarefas forem bem-sucedidas, ou para abort() se alguma tarefa falhar.
  • Se uma tarefa de escrita falhar, sua mensagem de commit será None na lista passada para commit() ou abort().

Exemplos

Implemente um gravador baseado em Arrow que conte as linhas em todos os lotes:

Python
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
num_rows: int

class MyDataSourceArrowWriter(DataSourceArrowWriter):
def write(self, iterator):
total_rows = 0
for batch in iterator:
total_rows += len(batch)
return MyCommitMessage(num_rows=total_rows)

def commit(self, messages):
total = sum(m.num_rows for m in messages if m is not None)
print(f"Committed {total} rows")

def abort(self, messages):
print("Write job failed, performing cleanup")