Interact with Spark DataFrames (Python)

Use LangChain to interact with Spark DataFrames

The following code demonstrates the Spark DataFrame Agent.

Requirements

  • An OpenAI API key (set as your OpenAI API token below).
  • Databricks Runtime 13.3 ML or above.

Install libraries

Databricks recommends the latest versions of langchain and databricks-sql-connector.

%pip install --upgrade langchain databricks-sql-connector
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
dbutils.library.restartPython()
import os
os.environ["OPENAI_API_KEY"] = ""
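Hardcoding the token in the notebook works for a quick demo, but a safer pattern is to read it from the environment or a Databricks secret scope. A minimal sketch (the helper name is illustrative, not part of LangChain or Databricks):

```python
import os

def get_openai_key():
    """Return the OpenAI API key from the environment, or None if unset."""
    # On Databricks you could instead fetch it from a secret scope, e.g.
    # dbutils.secrets.get(scope="my-scope", key="openai-api-key")
    return os.environ.get("OPENAI_API_KEY")

key = get_openai_key()
print("OPENAI_API_KEY configured:", key is not None)
```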

Spark DataFrame Agent

The Spark DataFrame Agent lets you ask questions about a Spark DataFrame in natural language. You simply call create_spark_dataframe_agent with an LLM and the DataFrame in question.

from langchain.llms import OpenAI
from langchain.agents import create_spark_dataframe_agent

df = spark.read.csv("/databricks-datasets/COVID/coronavirusdataset/Region.csv", header=True, inferSchema=True)
display(df)
 
| code  | province | city          | latitude  | longitude  | elementary_school_count | kindergarten_count | university_count | academy_ratio |
|-------|----------|---------------|-----------|------------|-------------------------|--------------------|------------------|---------------|
| 10000 | Seoul    | Seoul         | 37.566953 | 126.977977 | 607                     | 830                | 48               | 1.44          |
| 10010 | Seoul    | Gangnam-gu    | 37.518421 | 127.047222 | 33                      | 38                 | 0                | 4.18          |
| 10020 | Seoul    | Gangdong-gu   | 37.530492 | 127.123837 | 27                      | 32                 | 0                | 1.54          |
| 10030 | Seoul    | Gangbuk-gu    | 37.639938 | 127.025508 | 14                      | 21                 | 0                | 0.67          |
| 10040 | Seoul    | Gangseo-gu    | 37.551166 | 126.849506 | 36                      | 56                 | 1                | 1.17          |
| 10050 | Seoul    | Gwanak-gu     | 37.47829  | 126.951502 | 22                      | 33                 | 1                | 0.89          |
| 10060 | Seoul    | Gwangjin-gu   | 37.538712 | 127.082366 | 22                      | 33                 | 3                | 1.16          |
| 10070 | Seoul    | Guro-gu       | 37.495632 | 126.88765  | 26                      | 34                 | 3                | 1             |
| 10080 | Seoul    | Geumcheon-gu  | 37.456852 | 126.895229 | 18                      | 19                 | 0                | 0.96          |
| 10090 | Seoul    | Nowon-gu      | 37.654259 | 127.056294 | 42                      | 66                 | 6                | 1.39          |
| 10100 | Seoul    | Dobong-gu     | 37.668952 | 127.047082 | 23                      | 26                 | 1                | 0.95          |
| 10110 | Seoul    | Dongdaemun-gu | 37.574552 | 127.039721 | 21                      | 31                 | 4                | 1.06          |
| 10120 | Seoul    | Dongjak-gu    | 37.510571 | 126.963604 | 21                      | 34                 | 3                | 1.17          |
| 10130 | Seoul    | Mapo-gu       | 37.566283 | 126.901644 | 22                      | 24                 | 2                | 1.83          |
| 10140 | Seoul    | Seodaemun-gu  | 37.579428 | 126.936771 | 19                      | 25                 | 6                | 1.12          |
| 10150 | Seoul    | Seocho-gu     | 37.483804 | 127.032693 | 24                      | 27                 | 1                | 2.6           |
| 10160 | Seoul    | Seongdong-gu  | 37.563277 | 127.036647 | 21                      | 30                 | 2                | 0.97          |

Showing the first 17 of 244 rows.
agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)
agent.run("How many rows are there?")
> Entering new chain...
Thought: I need to find out how many rows are in the dataframe
Action: python_repl_ast
Action Input: df.count()
Observation: 244
Thought: I now know the final answer
Final Answer: There are 244 rows in the dataframe.
> Finished chain.
Out[4]: 'There are 244 rows in the dataframe.'
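The trace above shows the agent's loop: the LLM proposes a Python expression, a REPL tool (python_repl_ast) evaluates it against df, and the observation is fed back until the LLM produces a final answer. A minimal sketch of just the tool-execution step, with the LLM and the Spark DataFrame stubbed out for illustration:

```python
# Minimal sketch of the agent's tool step: the LLM proposes a Python
# expression, the REPL tool evaluates it with `df` in scope, and the
# result becomes the next Observation. The real tool uses an AST-based
# REPL; plain eval() here is a simplification for illustration only.
class FakeDataFrame:
    """Stand-in for the 244-row Spark DataFrame used above."""
    def count(self):
        return 244

def python_repl_tool(expression, df):
    # Evaluate the LLM-generated expression with `df` bound in scope.
    return eval(expression, {"df": df})

observation = python_repl_tool("df.count()", FakeDataFrame())
print(f"Observation: {observation}")
```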