Unit testing for notebooks

You can use unit testing to help improve the quality and consistency of your notebooks’ code. Unit testing is an approach to testing self-contained units of code, such as functions, early and often. This helps you find problems with your code faster, uncover mistaken assumptions about your code sooner, and streamline your overall coding efforts.
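
For example, a unit test for a simple function that adds two numbers asserts that the function returns the expected result for known inputs. The following minimal sketch uses pytest, which is covered later in this article; the file names and the add function are illustrative only and are not part of the examples that follow.

# myadd.py
def add(a, b):
  return a + b

# test_myadd.py
from myadd import add

def test_add():
  assert add(2, 3) == 5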

This article is an introduction to basic unit testing with functions. Advanced concepts such as unit testing classes and interfaces, as well as the use of stubs, mocks, and test harnesses, while also supported when unit testing for notebooks, are outside the scope of this article. This article also does not cover other kinds of testing methods, such as integration testing, system testing, acceptance testing, or non-functional testing methods such as performance testing or usability testing.

This article demonstrates the following:

  • How to organize functions and their unit tests.

  • How to write functions in Python, R, Scala, as well as user-defined functions in SQL, that are well-designed to be unit tested.

  • How to call these functions from Python, R, Scala, and SQL notebooks.

  • How to write unit tests in Python, R, and Scala by using the popular test frameworks pytest for Python, testthat for R, and ScalaTest for Scala, as well as how to write SQL that unit tests SQL user-defined functions (SQL UDFs).

  • How to run these unit tests from Python, R, Scala, and SQL notebooks.

Organize functions and unit tests

There are a few common approaches for organizing your functions and their unit tests with notebooks. Each approach has its benefits and challenges.

For Python, R, and Scala notebooks, common approaches include the following:

  • Store functions and their unit tests outside of notebooks.

    • Benefits: You can call these functions both within and outside of notebooks. Test frameworks are better designed to run tests outside of notebooks.

    • Challenges: This approach is not supported for Scala notebooks. This approach also increases the number of files to track and maintain.

  • Store functions in one notebook and their unit tests in a separate notebook.

    • Benefits: These functions are easier to reuse across notebooks.

    • Challenges: The number of notebooks to track and maintain increases. These functions cannot be used outside of notebooks. These functions can also be more difficult to test outside of notebooks.

  • Store functions and their unit tests within the same notebook.

    • Benefits: Functions and their unit tests are stored within a single notebook for easier tracking and maintenance.

    • Challenges: These functions can be more difficult to reuse across notebooks. These functions cannot be used outside of notebooks. These functions can also be more difficult to test outside of notebooks.

For Python and R notebooks, Databricks recommends storing functions and their unit tests outside of notebooks. For Scala notebooks, Databricks recommends including functions in one notebook and their unit tests in a separate notebook.

For SQL notebooks, Databricks recommends that you store functions as SQL user-defined functions (SQL UDFs) in your schemas (also known as databases). You can then call these SQL UDFs and their unit tests from SQL notebooks.
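
For example, a repo that follows the recommended approach for Python might contain a layout like the following. The function and test file names match the Python examples later in this article; the name of the notebook that runs the tests is only illustrative.

<repo-name>/
  myfunctions.py         # Functions to be tested
  test_myfunctions.py    # Unit tests for those functions
  run_unit_tests         # Notebook that installs pytest and runs the unit tests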

Write functions

This section describes a simple set of example functions that determine the following:

  • Whether a table exists in a database.

  • Whether a column exists in a table.

  • How many rows exist in a table for a specified value in a specified column.

These functions are intended to be simple, so that you can focus on the unit testing details in this article rather than focus on the functions themselves.

To get the best unit testing results, a function should return a single predictable outcome of a single data type. For example, to check whether something exists, the function should return a boolean value of true or false. To return the number of rows that exist, the function should return a non-negative whole number. In the first example, the function should not return false if something does not exist but the thing itself if it does exist. Likewise, in the second example, it should not return either the number of rows that exist or false if no rows exist.
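
The following minimal Python sketch illustrates this guideline. The function names here are illustrative and are not part of the examples that follow.

# Recommended: always returns a boolean.
def column_exists(df, column_name):
  return column_name in df.columns

# Not recommended: returns either a row count or a boolean,
# so callers and unit tests must handle two different types.
def rows_or_false(df, column_name, column_value):
  count = df.filter(df[column_name] == column_value).count()
  return count if count > 0 else False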

You can add these functions to an existing Databricks workspace as follows, in Python, R, Scala, or SQL.

The following code assumes you have Set up Databricks Git folders (Repos), added a repo, and have the repo open in your Databricks workspace.

Create a file named myfunctions.py within the repo, and add the following contents to the file. Other examples in this article expect this file to be named myfunctions.py. You can use different names for your own files.

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
                    .appName('integrity-tests') \
                    .getOrCreate()

# Does the specified table exist in the specified database?
def tableExists(tableName, dbName):
  return spark.catalog.tableExists(f"{dbName}.{tableName}")

# Does the specified column exist in the given DataFrame?
def columnExists(dataFrame, columnName):
  return columnName in dataFrame.columns

# How many rows are there for the specified value in the specified column
# in the given DataFrame?
def numRowsInColumnForValue(dataFrame, columnName, columnValue):
  df = dataFrame.filter(col(columnName) == columnValue)

  return df.count()
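
Because these functions live in a standalone .py file, you can also call them outside of notebooks, for example from a local script or test runner. The following minimal sketch assumes that pyspark is installed in your local environment; the DataFrame contents are illustrative only.

from pyspark.sql import SparkSession

from myfunctions import columnExists, numRowsInColumnForValue

spark = SparkSession.builder.appName('local-check').getOrCreate()

# A small, in-memory DataFrame to call the functions against.
df = spark.createDataFrame(
  [(1, "SI2"), (2, "VVS2")],
  ["id", "clarity"]
)

print(columnExists(df, "clarity"))                    # True
print(numRowsInColumnForValue(df, "clarity", "VVS2")) # 1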

The following code assumes you have Set up Databricks Git folders (Repos), added a repo, and have the repo open in your Databricks workspace.

Create a file named myfunctions.r within the repo, and add the following contents to the file. Other examples in this article expect this file to be named myfunctions.r. You can use different names for your own files.

library(SparkR)

# Does the specified table exist in the specified database?
table_exists <- function(table_name, db_name) {
  tableExists(paste(db_name, ".", table_name, sep = ""))
}

# Does the specified column exist in the given DataFrame?
column_exists <- function(dataframe, column_name) {
  column_name %in% colnames(dataframe)
}

# How many rows are there for the specified value in the specified column
# in the given DataFrame?
num_rows_in_column_for_value <- function(dataframe, column_name, column_value) {
  df = filter(dataframe, dataframe[[column_name]] == column_value)

  count(df)
}

Create a Scala notebook named myfunctions with the following contents. Other examples in this article expect this notebook to be named myfunctions. You can use different names for your own notebooks.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Does the specified table exist in the specified database?
def tableExists(tableName: String, dbName: String) : Boolean = {
  return spark.catalog.tableExists(dbName + "." + tableName)
}

// Does the specified column exist in the given DataFrame?
def columnExists(dataFrame: DataFrame, columnName: String) : Boolean = {
  dataFrame.columns.contains(columnName)
}

// How many rows are there for the specified value in the specified column
// in the given DataFrame?
def numRowsInColumnForValue(dataFrame: DataFrame, columnName: String, columnValue: String) : Long = {
  val df = dataFrame.filter(col(columnName) === columnValue)

  return df.count()
}

The following code assumes you have the third-party sample dataset diamonds within a schema named default within a catalog named main that is accessible from your Databricks workspace. If the catalog or schema that you want to use has a different name, then change one or both of the following USE statements to match.

Create a SQL notebook and add the following contents to this new notebook. Then attach the notebook to a cluster and run the notebook to add the following SQL UDFs to the specified catalog and schema.

Note

The SQL UDFs table_exists and column_exists work only with Unity Catalog. SQL UDF support for Unity Catalog is in Public Preview.

USE CATALOG main;
USE SCHEMA default;

CREATE OR REPLACE FUNCTION table_exists(catalog_name STRING,
                                        db_name      STRING,
                                        table_name   STRING)
  RETURNS BOOLEAN
  RETURN if(
    (SELECT count(*) FROM system.information_schema.tables
     WHERE table_catalog = table_exists.catalog_name
       AND table_schema  = table_exists.db_name
       AND table_name    = table_exists.table_name) > 0,
    true,
    false
  );

CREATE OR REPLACE FUNCTION column_exists(catalog_name STRING,
                                         db_name      STRING,
                                         table_name   STRING,
                                         column_name  STRING)
  RETURNS BOOLEAN
  RETURN if(
    (SELECT count(*) FROM system.information_schema.columns
     WHERE table_catalog = column_exists.catalog_name
       AND table_schema  = column_exists.db_name
       AND table_name    = column_exists.table_name
       AND column_name   = column_exists.column_name) > 0,
    true,
    false
  );

CREATE OR REPLACE FUNCTION num_rows_for_clarity_in_diamonds(clarity_value STRING)
  RETURNS BIGINT
  RETURN SELECT count(*)
         FROM main.default.diamonds
         WHERE clarity = clarity_value;

Call functions

This section describes code that calls the preceding functions. You could use these functions, for example, to count the number of rows in a table where a specified value exists within a specified column. However, you would want to check whether the table actually exists, and whether the column actually exists in that table, before you proceed. The following code checks for these conditions.

If you added the functions from the preceding section to your Databricks workspace, you can call these functions from your workspace as follows.

Create a Python notebook in the same folder as the preceding myfunctions.py file in your repo, and add the following contents to the notebook. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.

from myfunctions import *

tableName   = "diamonds"
dbName      = "default"
columnName  = "clarity"
columnValue = "VVS2"

# If the table exists in the specified database...
if tableExists(tableName, dbName):

  df = spark.sql(f"SELECT * FROM {dbName}.{tableName}")

  # And the specified column exists in that table...
  if columnExists(df, columnName):
    # Then report the number of rows for the specified value in that column.
    numRows = numRowsInColumnForValue(df, columnName, columnValue)

    print(f"There are {numRows} rows in '{tableName}' where '{columnName}' equals '{columnValue}'.")
  else:
    print(f"Column '{columnName}' does not exist in table '{tableName}' in schema (database) '{dbName}'.")
else:
  print(f"Table '{tableName}' does not exist in schema (database) '{dbName}'.") 

Create an R notebook in the same folder as the preceding myfunctions.r file in your repo, and add the following contents to the notebook. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.

library(SparkR)
source("myfunctions.r")

table_name   <- "diamonds"
db_name      <- "default"
column_name  <- "clarity"
column_value <- "VVS2"

# If the table exists in the specified database...
if (table_exists(table_name, db_name)) {

  df = sql(paste("SELECT * FROM ", db_name, ".", table_name, sep = ""))

  # And the specified column exists in that table...
  if (column_exists(df, column_name)) {
    # Then report the number of rows for the specified value in that column.
    num_rows = num_rows_in_column_for_value(df, column_name, column_value)

    print(paste("There are ", num_rows, " rows in table '", table_name, "' where '", column_name, "' equals '", column_value, "'.", sep = "")) 
  } else {
    print(paste("Column '", column_name, "' does not exist in table '", table_name, "' in schema (database) '", db_name, "'.", sep = ""))
  }

} else {
  print(paste("Table '", table_name, "' does not exist in schema (database) '", db_name, "'.", sep = ""))
}

Create another Scala notebook in the same folder as the preceding myfunctions Scala notebook, and add the following contents to this new notebook.

In this new notebook’s first cell, add the following code, which calls the %run magic. This magic makes the contents of the myfunctions notebook available to your new notebook.

%run ./myfunctions

In this new notebook’s second cell, add the following code. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.

val tableName   = "diamonds"
val dbName      = "default"
val columnName  = "clarity"
val columnValue = "VVS2"

// If the table exists in the specified database...
if (tableExists(tableName, dbName)) {

  val df = spark.sql("SELECT * FROM " + dbName + "." + tableName)

  // And the specified column exists in that table...
  if (columnExists(df, columnName)) {
    // Then report the number of rows for the specified value in that column.
    val numRows = numRowsInColumnForValue(df, columnName, columnValue)

    println("There are " + numRows + " rows in '" + tableName + "' where '" + columnName + "' equals '" + columnValue + "'.")
  } else {
    println("Column '" + columnName + "' does not exist in table '" + tableName + "' in database '" + dbName + "'.")
  }

} else {
  println("Table '" + tableName + "' does not exist in database '" + dbName + "'.")
}

Add the following code to a new cell in the preceding notebook or to a cell in a separate notebook. Change the schema or catalog names if necessary to match yours, and then run this cell to see the results.

SELECT CASE
-- If the table exists in the specified catalog and schema...
WHEN
  table_exists("main", "default", "diamonds")
THEN
  -- And the specified column exists in that table...
  (SELECT CASE
   WHEN
     column_exists("main", "default", "diamonds", "clarity")
   THEN
     -- Then report the number of rows for the specified value in that column.
     printf("There are %d rows in table 'main.default.diamonds' where 'clarity' equals 'VVS2'.",
            num_rows_for_clarity_in_diamonds("VVS2"))
   ELSE
     printf("Column 'clarity' does not exist in table 'main.default.diamonds'.")
   END)
ELSE
  printf("Table 'main.default.diamonds' does not exist.")
END

Write unit tests

This section describes code that tests each of the functions that are described toward the beginning of this article. If you make any changes to functions in the future, you can use unit tests to determine whether those functions still work as you expect them to.

If you added the functions toward the beginning of this article to your Databricks workspace, you can add unit tests for these functions to your workspace as follows.

Create another file named test_myfunctions.py in the same folder as the preceding myfunctions.py file in your repo, and add the following contents to the file. By default, pytest looks for .py files whose names start with test_ (or end with _test) to test. Similarly, by default, pytest looks inside of these files for functions whose names start with test_ to test.

In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.

import pytest
import pyspark
from myfunctions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType

tableName    = "diamonds"
dbName       = "default"
columnName   = "clarity"
columnValue  = "SI2"

# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
                    .appName('integrity-tests') \
                    .getOrCreate()

# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema = StructType([ \
  StructField("_c0",     IntegerType(), True), \
  StructField("carat",   FloatType(),   True), \
  StructField("cut",     StringType(),  True), \
  StructField("color",   StringType(),  True), \
  StructField("clarity", StringType(),  True), \
  StructField("depth",   FloatType(),   True), \
  StructField("table",   IntegerType(), True), \
  StructField("price",   IntegerType(), True), \
  StructField("x",       FloatType(),   True), \
  StructField("y",       FloatType(),   True), \
  StructField("z",       FloatType(),   True), \
])

data = [ (1, 0.23, "Ideal",   "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43 ), \
         (2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31 ) ]

df = spark.createDataFrame(data, schema)

# Does the table exist?
def test_tableExists():
  assert tableExists(tableName, dbName) is True

# Does the column exist?
def test_columnExists():
  assert columnExists(df, columnName) is True

# Is there at least one row for the value in the specified column?
def test_numRowsInColumnForValue():
  assert numRowsInColumnForValue(df, columnName, columnValue) > 0

Create another file named test_myfunctions.r in the same folder as the preceding myfunctions.r file in your repo, and add the following contents to the file. By default, testthat looks for .r files whose names start with test to test.

In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.

library(testthat)
source("myfunctions.r")

table_name   <- "diamonds"
db_name      <- "default"
column_name  <- "clarity"
column_value <- "SI2"

# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema <- structType(
  structField("_c0",     "integer"),
  structField("carat",   "float"),
  structField("cut",     "string"),
  structField("color",   "string"),
  structField("clarity", "string"),
  structField("depth",   "float"),
  structField("table",   "integer"),
  structField("price",   "integer"),
  structField("x",       "float"),
  structField("y",       "float"),
  structField("z",       "float"))

data <- list(list(as.integer(1), 0.23, "Ideal",   "E", "SI2", 61.5, as.integer(55), as.integer(326), 3.95, 3.98, 2.43),
             list(as.integer(2), 0.21, "Premium", "E", "SI1", 59.8, as.integer(61), as.integer(326), 3.89, 3.84, 2.31))

df <- createDataFrame(data, schema)

# Does the table exist?
test_that ("The table exists.", {
  expect_true(table_exists(table_name, db_name))
})

# Does the column exist?
test_that ("The column exists in the table.", {
  expect_true(column_exists(df, column_name))
})

# Is there at least one row for the value in the specified column?
test_that ("There is at least one row in the query result.", {
  expect_true(num_rows_in_column_for_value(df, column_name, column_value) > 0)
})

Create another Scala notebook in the same folder as the preceding myfunctions Scala notebook, and add the following contents to this new notebook.

In the new notebook’s first cell, add the following code, which calls the %run magic. This magic makes the contents of the myfunctions notebook available to your new notebook.

%run ./myfunctions

In the second cell, add the following code. This code defines your unit tests and specifies how to run them.

In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.

import org.scalatest._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructType, StructField, IntegerType, FloatType, StringType}
import scala.collection.JavaConverters._

class DataTests extends AsyncFunSuite {

  val tableName   = "diamonds"
  val dbName      = "default"
  val columnName  = "clarity"
  val columnValue = "SI2"

  // Create fake data for the unit tests to run against.
  // In general, it is a best practice to not run unit tests
  // against functions that work with data in production.
  val schema = StructType(Array(
                 StructField("_c0",     IntegerType),
                 StructField("carat",   FloatType),
                 StructField("cut",     StringType),
                 StructField("color",   StringType),
                 StructField("clarity", StringType),
                 StructField("depth",   FloatType),
                 StructField("table",   IntegerType),
                 StructField("price",   IntegerType),
                 StructField("x",       FloatType),
                 StructField("y",       FloatType),
                 StructField("z",       FloatType)
               ))

  val data = Seq(
                  Row(1, 0.23, "Ideal",   "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43),
                  Row(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31)
                ).asJava

  val df = spark.createDataFrame(data, schema)

  // Does the table exist?
  test("The table exists") {
    assert(tableExists(tableName, dbName) == true)
  }

  // Does the column exist?
  test("The column exists") {
    assert(columnExists(df, columnName) == true)
  }

  // Is there at least one row for the value in the specified column?
  test("There is at least one matching row") {
    assert(numRowsInColumnForValue(df, columnName, columnValue) > 0)
  }
}

nocolor.nodurations.nostacks.stats.run(new DataTests)

Note

This code example uses the FunSuite style of testing in ScalaTest. For other available testing styles, see Selecting testing styles for your project.

Before you add unit tests, you should be aware that in general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to run unit tests against views instead of tables.

To create a view, you can call the CREATE VIEW command from a new cell in either the preceding notebook or a separate notebook. The following example assumes that you have an existing table named diamonds within a schema (database) named default within a catalog named main. Change these names to match your own as needed, and then run only that cell.

USE CATALOG main;
USE SCHEMA default;

CREATE VIEW view_diamonds AS
SELECT * FROM diamonds;

After you create the view, add each of the following SELECT statements to its own new cell in the preceding notebook or to its own new cell in a separate notebook. Change the names to match your own as needed.

SELECT if(table_exists("main", "default", "view_diamonds"),
          printf("PASS: The table 'main.default.view_diamonds' exists."),
          printf("FAIL: The table 'main.default.view_diamonds' does not exist."));

SELECT if(column_exists("main", "default", "view_diamonds", "clarity"),
          printf("PASS: The column 'clarity' exists in the table 'main.default.view_diamonds'."),
          printf("FAIL: The column 'clarity' does not exists in the table 'main.default.view_diamonds'."));

SELECT if(num_rows_for_clarity_in_diamonds("VVS2") > 0,
          printf("PASS: The table 'main.default.view_diamonds' has at least one row where the column 'clarity' equals 'VVS2'."),
          printf("FAIL: The table 'main.default.view_diamonds' does not have at least one row where the column 'clarity' equals 'VVS2'."));

Run unit tests

This section describes how to run the unit tests that you coded in the preceding section. When you run the unit tests, you get results showing which unit tests passed and failed.

If you added the unit tests from the preceding section to your Databricks workspace, you can run these unit tests from your workspace. You can run these unit tests either manually or on a schedule.

Create a Python notebook in the same folder as the preceding test_myfunctions.py file in your repo, and add the following contents.

In the new notebook’s first cell, add the following code, and then run the cell, which calls the %pip magic. This magic installs pytest.

%pip install pytest

In the second cell, add the following code, replace <my-repo-name> with the folder name for your repo, and then run the cell. Results show which unit tests passed and failed.

import pytest
import os
import sys

repo_name = "<my-repo-name>"

# Get the path to this notebook, for example "/Workspace/Repos/{username}/{repo-name}".
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()

# Get the repo's root directory name.
repo_root = os.path.dirname(os.path.dirname(notebook_path))

# Prepare to run pytest from the repo.
os.chdir(f"/Workspace/{repo_root}/{repo_name}")
print(os.getcwd())

# Skip writing pyc files on a readonly filesystem.
sys.dont_write_bytecode = True

# Run pytest.
retcode = pytest.main([".", "-v", "-p", "no:cacheprovider"])

# Fail the cell execution if there are any test failures.
assert retcode == 0, "The pytest invocation failed. See the log for details."
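
If you want to run only some of the tests, you can pass more specific arguments to pytest.main instead of the whole directory. For example, the following variation, which uses the file and test names from earlier in this article, runs a single test function:

retcode = pytest.main(["test_myfunctions.py::test_columnExists", "-v", "-p", "no:cacheprovider"])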

Create an R notebook in the same folder as the preceding test_myfunctions.r file in your repo, and add the following contents.

In the first cell, add the following code, and then run the cell, which calls the install.packages function. This function installs testthat.

install.packages("testthat")

In the second cell, add the following code, and then run the cell. Results show which unit tests passed and failed.

library(testthat)
source("myfunctions.r")

test_dir(".", reporter = "tap")

Run the first cell and then the second cell in the notebook from the preceding section. Results show which unit tests passed and failed.

Run each of the three cells in the notebook from the preceding section. Results show whether each unit test passed or failed.

If you no longer need the view after you run your unit tests, you can delete the view. To delete this view, you can add the following code to a new cell within one of the preceding notebooks and then run only that cell.

DROP VIEW view_diamonds;

Tip

You can view the results of your notebook runs (including unit test results) in your cluster’s driver logs. You can also specify a location for your cluster’s log delivery.

You can set up a continuous integration and continuous delivery or deployment (CI/CD) system, such as GitHub Actions, to automatically run your unit tests whenever your code changes. For an example, see the coverage of GitHub Actions in Software engineering best practices for notebooks.