Connecting to SQL Databases using JDBC

In this example, we’ll query MySQL using its JDBC driver. You can use other SQL databases as well, including (but not limited to) PostgreSQL and Oracle; you will need to install the corresponding JDBC driver.

The upstream Apache Spark documentation describes the supported arguments. This guide covers the DataFrame API syntax and how to control the parallelism of reads through the JDBC interface. This notebook covers the Scala APIs first, with the other languages at the end of the notebook.

Databricks VPCs are configured to allow only Spark clusters within them. The best practice is to set up VPC peering between the Databricks VPC and any other infrastructure you need to connect to. Once VPC peering is established, you can verify connectivity with the netcat utility from a cluster:

%sh nc -vz db_hostname db_port

Load your JDBC Driver onto Databricks

  • Databricks comes preloaded with JDBC libraries for MySQL, but you can attach other JDBC libraries and reference them in your code
  • See our Libraries Documentation for instructions on how to install a Java JAR.

Check that the JDBC Driver is available

  • If connecting to Postgres, the class org.postgresql.Driver already exists in Databricks. This command checks that the driver class exists on your classpath. You can use a %scala cell to test this from notebooks in other languages, such as Python.
//Class.forName("org.postgresql.Driver")
Class.forName("com.mysql.jdbc.Driver")

Load Your SQL Configuration

import java.util.Properties

// Option 1: Build the parameters into a JDBC url to pass into the DataFrame APIs
val jdbcUsername = "USER_NAME"
val jdbcPassword = "PASSWORD"
val jdbcHostname = "HOSTNAME"
val jdbcPort = 3306
val jdbcDatabase ="DATABASE"
val jdbcUrl = s"jdbc:mysql://${jdbcHostname}:${jdbcPort}/${jdbcDatabase}?user=${jdbcUsername}&password=${jdbcPassword}"

// Option 2: Create a Properties() object to hold the parameters. You can create the JDBC URL without passing in the user/password parameters directly.
val connectionProperties = new Properties()
connectionProperties.put("user", "USER_NAME")
connectionProperties.put("password", "PASSWORD")

Check Connectivity to your SQL Database

import java.sql.DriverManager
val connection = DriverManager.getConnection(jdbcUrl, jdbcUsername, jdbcPassword)
connection.isClosed() // returns false if the connection was established successfully
connection.close()

Reading data from JDBC

In this section, we’ll load data from a MySQL table that already exists. We’ll use the connectionProperties object defined above. By default, Spark uses a single JDBC connection to pull the table into the Spark environment. We discuss parallel reads in a later section.

val jdbc_url = s"jdbc:mysql://${jdbcHostname}:${jdbcPort}/${jdbcDatabase}"
val employees_table = spark.read.jdbc(jdbc_url, "employees", connectionProperties)
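
The same read can also be expressed with the generic DataFrameReader option syntax; this sketch assumes the credentials defined earlier:

// Sketch: the equivalent read using format("jdbc") and options instead of the jdbc() shorthand
val employees_table_alt = spark.read
  .format("jdbc")
  .option("url", jdbc_url)
  .option("dbtable", "employees")
  .option("user", jdbcUsername)
  .option("password", jdbcPassword)
  .load()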

As you can see, Spark automatically reads the schema from the MySQL table and maps its types back to Spark SQL types.

employees_table.printSchema

We can run queries against this JDBC table:

display(employees_table.select("age", "salary").groupBy("age").avg("salary"))

Writing data to JDBC

In this section, we’ll show how to write data to MySQL for use in later examples. We’ll push an existing Spark SQL table named diamonds into our MySQL instance.

%sql -- quick test that this test table exists
select * from diamonds limit 5

Let’s save this into a MySQL table named diamonds_mysql (the table name in MySQL could be the same; we use a different name in this example). Existing column names that are reserved keywords can trigger an exception. Our example table has table as a column name, so we’ll rename it with the withColumnRenamed() API before pushing it through the JDBC interface.

// Create a dataframe from a Spark SQL Table
// Reserved words in SQL (like "table") will trigger an exception if used as column names.
// The solution is to rename the column prior to writing via jdbc
val jdbc_url = s"jdbc:mysql://${jdbcHostname}:${jdbcPort}/${jdbcDatabase}"
spark.table("diamonds").withColumnRenamed("table", "table_number")
    .write
    .jdbc(jdbc_url, "diamonds_mysql", connectionProperties)

Spark automatically creates a MySQL table with the appropriate schema determined from the DataFrame schema.

The default behavior is to create a new table and to throw an error message if a table with the same name already exists. You can use Spark SQL’s SaveMode feature to change this behavior. For example, let’s append more rows to our table:

import org.apache.spark.sql.SaveMode

spark.sql("select * from diamonds limit 10").withColumnRenamed("table", "table_number")
  .write
  .mode(SaveMode.Append) // <--- Append to the existing table
  .jdbc(jdbc_url, "diamonds_mysql", connectionProperties)

You can also overwrite the existing table:

spark.table("diamonds").withColumnRenamed("table", "table_number")
  .write
  .mode(SaveMode.Overwrite) // <--- Append to the existing table
  .jdbc(jdbc_url, "diamonds_mysql", connectionProperties)

Pushdown Query to Database Engine

You can push down an entire query to the database and return just the result. This uses the dbtable parameter, as documented upstream:

“The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.”
// Note: The parentheses around the subquery are necessary.
val pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
val df = spark.read.jdbc(url = jdbcUrl, table = pushdown_query, properties = connectionProperties)
display(df)

Managing Parallelism

JDBC Reads

You can provide split boundaries based on the values of a column in the dataset.

These options must all be specified if any of them is specified.

Note

These options control only the parallelism of the table read. lowerBound and upperBound determine the partition stride; they do not filter the rows of the table. Spark therefore partitions and returns all rows in the table.

We will split the table read across executors on the emp_no column using the partitionColumn, lowerBound, upperBound, and numPartitions parameters.

val df = (spark.read.jdbc(url=jdbcUrl,
    table="employees",
    columnName="emp_no",
    lowerBound=1L,
    upperBound=100000L,
    numPartitions=100,
    connectionProperties=connectionProperties))
display(df)
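
If no suitable numeric column is available, another way to parallelize the read is to pass an explicit list of predicates, one per partition; this is a sketch using illustrative ranges on emp_no:

// Sketch: one partition per predicate; each string becomes the WHERE clause for one task
val predicates = Array(
  "emp_no < 20000",
  "emp_no >= 20000 AND emp_no < 40000",
  "emp_no >= 40000")
val df_predicates = spark.read.jdbc(jdbcUrl, "employees", predicates, connectionProperties)
display(df_predicates)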

In the Spark UI, you will see that the number of partitions dictates the number of tasks that are launched. Each task is spread across the executors, which can increase the parallelism of reads and writes through the JDBC interface. See the upstream guide for other parameters that can help with performance, such as the fetchsize option.
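
For example, fetchsize controls how many rows the driver fetches per round trip and can be passed through the connection properties; a minimal sketch, with an illustrative value:

// Sketch: tune the number of rows fetched per round trip (the value 1000 is illustrative)
val fetchProperties = new java.util.Properties()
fetchProperties.putAll(connectionProperties)
fetchProperties.put("fetchsize", "1000")
val employees_fetch = spark.read.jdbc(jdbcUrl, "employees", fetchProperties)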

JDBC Writes

Spark’s partitions dictate the number of connections used to push data through the JDBC API. You can control the parallelism by calling coalesce(N) or repartition(N), depending on the existing number of partitions: call coalesce to reduce the number of partitions and repartition to increase it.

import org.apache.spark.sql.SaveMode

val df = spark.table("diamonds")
println(df.rdd.partitions.length)

// Given the partition count above, reduce it with coalesce() or increase it with repartition() to manage the number of connections.
df.repartition(10).withColumnRenamed("table", "table_number")
  .write.mode(SaveMode.Append).jdbc(jdbcUrl, "diamonds_mysql", connectionProperties)
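
Conversely, to reduce the number of concurrent connections, coalesce to fewer partitions; a sketch:

// Sketch: reduce the number of concurrent JDBC connections by coalescing to fewer partitions
df.coalesce(2).withColumnRenamed("table", "table_number")
  .write.mode(SaveMode.Append).jdbc(jdbcUrl, "diamonds_mysql", connectionProperties)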

Python Example

We will cover the same topics as above using Python syntax.

# Set up the JDBC URL and parameters
# username and password are assumed to be defined elsewhere (for example, in a credentials setup cell)
hostname = "example.databricks.rds.amazonaws.com"
dbname = "employees"
jdbcPort = 3306

jdbcUrl = "jdbc:mysql://{0}:{1}/{2}?user={3}&password={4}".format(hostname, jdbcPort, dbname, username, password)

You can pass in a dictionary containing the credentials and the driver class, similar to the Scala example above.

jdbc_url = "jdbc:mysql://{0}:{1}/{2}".format(hostname, jdbcPort, dbname)
# For SQLServer, pass in the "driver" option
# driverClass = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
# Add "driver" : driverClass
connectionProperties = {
  "user" : username,
  "password" : password
}

pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
df = spark.read.jdbc(url=jdbc_url, table=pushdown_query, properties=connectionProperties)
display(df)

Reading from JDBC connections across multiple workers:

df = spark.read.jdbc(
    url=jdbcUrl,
    table="employees",
    column="emp_no",
    lowerBound=1,
    upperBound=100000,
    numPartitions=100)
display(df)

Microsoft SQL Server Example

For Microsoft SQL Server, you will need to download the JDBC driver for this specific system and attach the JAR to the cluster using our library management workflow.

The following cell tests whether the JDBC driver class was loaded successfully. Inspect the JAR to ensure the class exists. Once the JAR is attached to the cluster, detach and reattach the notebook to the cluster.

// Testing for the JDBC Driver Class
Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver")

// Define the credentials and parameters
val (user, passwd) = get_sqlserver_creds // user-defined helper assumed to return (username, password)
val hostname = "example.databricks.sqlserver.com"
val dbName = "mydb"
val jdbcPort = 1433

val jdbcUrl = (s"jdbc:sqlserver://${hostname}:${jdbcPort};database=${dbName};user=${user};password=${passwd}")

val driverClass = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
val connectionProperties = new java.util.Properties()
connectionProperties.setProperty("Driver",driverClass)
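
With the URL and connection properties defined, reads look the same as in the MySQL examples; a sketch with a hypothetical table name:

// Sketch: read a SQL Server table ("dbo.Employees" is a hypothetical table name)
val sqlserver_df = spark.read.jdbc(jdbcUrl, "dbo.Employees", connectionProperties)
display(sqlserver_df)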

Pushdown Optimization

In addition to ingesting an entire table, you can push down a query to the database so that the database does the processing and only the result is returned.

You can prune columns and push down query predicates to the database with DataFrame methods.

// Explain plan with no column selection returns all columns
spark.read.jdbc(jdbcUrl, "diamonds_mysql", connectionProperties).explain(true)
// Explain plan with column selection will prune columns and just return the ones specified
// Notice that only the 3 specified columns are in the explain plan
spark.read.jdbc(jdbcUrl, "diamonds_mysql", connectionProperties).select("carat", "cut", "price").explain(true)
// You can push query predicates down too
// Notice the Filter at the top of the Physical Plan
spark.read.jdbc(jdbcUrl, "diamonds_mysql", connectionProperties).select("carat", "cut", "price").where("cut = 'Good'").explain(true)

Define SparkSQL JDBC Table

You can define a Spark SQL table or view that uses the JDBC connection under the covers. For details on creating a view, see the upstream docs.

CREATE TABLE jdbcTable
USING org.apache.spark.sql.jdbc
OPTIONS (
  url "jdbc:mysql://jdbcHostname:jdbcPort",
  dbtable "database.tablename",
  user "username",
  password "password"
)
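
From Scala, you can also register a JDBC-backed DataFrame as a temporary view and query it with Spark SQL. A minimal sketch, assuming the MySQL jdbcUrl and connectionProperties defined earlier are in scope:

// Sketch: expose a JDBC table to Spark SQL as a temporary view (assumes the MySQL connection settings from earlier)
spark.read.jdbc(jdbcUrl, "diamonds_mysql", connectionProperties)
  .createOrReplaceTempView("diamonds_mysql_view")
display(spark.sql("SELECT cut, avg(price) AS avg_price FROM diamonds_mysql_view GROUP BY cut"))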

Append data into the MySQL table using the Spark SQL API

%sql
INSERT INTO diamonds_mysql
SELECT * FROM diamonds LIMIT 10 -- append 10 records to the table
%sql
SELECT count(*) record_count FROM diamonds_mysql --count increased by 10

Overwrite data in the MySQL table using the Spark SQL API

  • This will cause the diamonds_mysql table in MySQL to be dropped and re-created
%sql
INSERT OVERWRITE TABLE diamonds_mysql
SELECT carat, cut, color, clarity, depth, `table` AS table_number, price, x, y, z FROM diamonds
%sql
SELECT count(*) record_count FROM diamonds_mysql --count returned to original value (10 less)