ノートブックの単体テスト
単体テスト を使用すると、ノートブックのコードの品質と一貫性を向上させることができます。単体テストは、関数などの自己完結型のコード単位を早期かつ頻繁にテストするアプローチです。 これにより、コードの問題をより迅速に発見し、コードに関する誤った仮定を早期に発見し、全体的なコーディング作業を効率化できます。
この記事では、関数を使用した基本的な 単体テスト の概要です。 単体テストのクラスやインターフェイス、 スタブ、 モック、 テスト ハーネスの使用などの高度な概念は、ノートブックの単体テストでもサポートされていますが、この記事では説明しません。 この記事では、統合テスト 、システムテスト 、受け入れテスト 、または パフォーマンステスト やユーザビリティテスト などの非機能テスト 方法など、他の種類のテスト方法については説明しません。
この記事では、以下のトピックを紹介します。
- 関数とその単体テストを整理する方法。
- Python、R、Scala での関数の書き方、および SQL でのユーザー定義関数の書き方で、単体テストに適した設計になっています。
- Python、R、Scala、SQL ノートブックからこれらの関数を呼び出す方法。
- Python、R、Scala の一般的なテスト フレームワークである pytest for Python、 testthat 、および Scala for Scala を使用して単体テストを記述する方法。 また、ユニット テストの SQL ユーザー定義関数 (SQL UDF) の SQL の記述方法も説明します。
- これらの単体テストを Python、R、Scala、SQL ノートブックから実行する方法。
Databricks では、単体テストをノートブックに記述して実行することをお勧めします。 Webターミナルで一部のコマンドを実行できますが、WebターミナルにはSparkのサポートがないなど、より多くの制限があります。 「Databricks Webターミナルの実行 シェル コマンド」を参照してください。
関数と単体テストの整理
ノートブックを使用して関数とその単体テストを整理するための一般的な方法がいくつかあります。 それぞれのアプローチには、それぞれの利点と課題があります。
Python、R、Scala ノートブックの場合、一般的なアプローチは次のとおりです。
-
- 利点: これらの関数は、ノートブックの内外で呼び出すことができます。 テスト フレームワークは、ノートブックの外部でテストを実行するように適切に設計されています。
- 課題: このアプローチは Scala ノートブックではサポートされていません。 このアプローチでは、追跡および保守するファイルの数も増加します。
-
関数を 1 つのノートブックに格納し、その単体テストを別のノートブックに格納します。
- 利点: これらの関数は、ノートブック間で再利用しやすくなります。
- 課題: 追跡および保守するノートブックの数が増加します。 これらの機能は、ノートブックの外部では使用できません。 また、これらの機能は、ノートブックの外部でテストするのがより困難になる場合があります。
-
- 利点: 関数とその単体テストは 1 つのノートブックに保存されるため、追跡と保守が容易になります。
- 課題: これらの関数は、ノートブック間で再利用するのがより難しくなる可能性があります。 これらの機能は、ノートブックの外部では使用できません。 また、これらの機能は、ノートブックの外部でテストするのがより困難になる場合があります。
Python ノートブックと R ノートブックの場合、Databricks では関数とその単体テストをノートブックの外部に格納することをお勧めします。 Scalaノートブックの場合、Databricks では、関数を 1 つのノートブックに含め、その単体テストを別のノートブックに含めることをお勧めします。
SQL ノートブックの場合、Databricks では、関数を SQL ユーザー定義関数 (SQL UDF) としてスキーマ (データベースとも呼ばれます) に格納することをお勧めします。 その後、これらの SQL UDF とその単体テストを SQL ノートブックから呼び出すことができます。
書き込み関数
このセクションでは、以下を決定する関数の簡単な例のセットについて説明します。
- データベースにテーブルが存在するかどうか。
- テーブルに列が存在するかどうか。
- その列内の値に対して、列に存在する行数。
これらの関数は単純に作られているので、関数そのものに集中するよりも、この記事のユニットテストの詳細に集中できます。
最適な単体テスト結果を得るには、関数は 1 つの予測可能な結果を返し、1 つのデータ型である必要があります。 たとえば、何かが存在するかどうかを確認するには、関数は true または false のブール値を返す必要があります。 存在する行数を返すには、関数は負でない整数を返す必要があります。 最初の例では、何かが存在しない場合は false を返し、存在する場合はそれ自体を返すべきではありません。 同様に、2 番目の例では、存在する行の数を返さず、行が存在しない場合は false を返さないでください。
これらの関数は、Python、R、Scala、または SQL で次のように、既存の Databricks ワークスペースに追加できます。
- Python
- R
- Scala
- SQL
The following code assumes you have Set up Databricks Git folders (Repos), added a repo, and have the repo open in your Databricks workspace.
Create a file named myfunctions.py
within the repo, and add the following contents to the file. Other examples in this article expect this file to be named myfunctions.py
. You can use different names for your own files.
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
.appName('integrity-tests') \
.getOrCreate()
# Does the specified table exist in the specified database?
def tableExists(tableName, dbName):
return spark.catalog.tableExists(f"{dbName}.{tableName}")
# Does the specified column exist in the given DataFrame?
def columnExists(dataFrame, columnName):
if columnName in dataFrame.columns:
return True
else:
return False
# How many rows are there for the specified value in the specified column
# in the given DataFrame?
def numRowsInColumnForValue(dataFrame, columnName, columnValue):
df = dataFrame.filter(col(columnName) == columnValue)
return df.count()
The following code assumes you have Set up Databricks Git folders (Repos), added a repo, and have the repo open in your Databricks workspace.
Create a file named myfunctions.r
within the repo, and add the following contents to the file. Other examples in this article expect this file to be named myfunctions.r
. You can use different names for your own files.
library(SparkR)
# Does the specified table exist in the specified database?
table_exists <- function(table_name, db_name) {
tableExists(paste(db_name, ".", table_name, sep = ""))
}
# Does the specified column exist in the given DataFrame?
column_exists <- function(dataframe, column_name) {
column_name %in% colnames(dataframe)
}
# How many rows are there for the specified value in the specified column
# in the given DataFrame?
num_rows_in_column_for_value <- function(dataframe, column_name, column_value) {
df = filter(dataframe, dataframe[[column_name]] == column_value)
count(df)
}
Create a Scala notebook named myfunctions
with the following contents. Other examples in this article expect this notebook to be named myfunctions
. You can use different names for your own notebooks.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
// Does the specified table exist in the specified database?
def tableExists(tableName: String, dbName: String) : Boolean = {
return spark.catalog.tableExists(dbName + "." + tableName)
}
// Does the specified column exist in the given DataFrame?
def columnExists(dataFrame: DataFrame, columnName: String) : Boolean = {
val nameOfColumn = null
for(nameOfColumn <- dataFrame.columns) {
if (nameOfColumn == columnName) {
return true
}
}
return false
}
// How many rows are there for the specified value in the specified column
// in the given DataFrame?
def numRowsInColumnForValue(dataFrame: DataFrame, columnName: String, columnValue: String) : Long = {
val df = dataFrame.filter(col(columnName) === columnValue)
return df.count()
}
The following code assumes you have the third-party sample dataset diamonds within a schema named default
within a catalog named main
that is accessible from your Databricks workspace. If the catalog or schema that you want to use has a different name, then change one or both of the following USE
statements to match.
Create a SQL notebook and add the following contents to this new notebook. Then attach the notebook to a cluster and run the notebook to add the following SQL UDFs to the specified catalog and schema.
The SQL UDFs table_exists
and column_exists
work only with Unity Catalog. SQL UDF support for Unity Catalog is in Public Preview.
USE CATALOG main;
USE SCHEMA default;
CREATE OR REPLACE FUNCTION table_exists(catalog_name STRING,
db_name STRING,
table_name STRING)
RETURNS BOOLEAN
RETURN if(
(SELECT count(*) FROM system.information_schema.tables
WHERE table_catalog = table_exists.catalog_name
AND table_schema = table_exists.db_name
AND table_name = table_exists.table_name) > 0,
true,
false
);
CREATE OR REPLACE FUNCTION column_exists(catalog_name STRING,
db_name STRING,
table_name STRING,
column_name STRING)
RETURNS BOOLEAN
RETURN if(
(SELECT count(*) FROM system.information_schema.columns
WHERE table_catalog = column_exists.catalog_name
AND table_schema = column_exists.db_name
AND table_name = column_exists.table_name
AND column_name = column_exists.column_name) > 0,
true,
false
);
CREATE OR REPLACE FUNCTION num_rows_for_clarity_in_diamonds(clarity_value STRING)
RETURNS BIGINT
RETURN SELECT count(*)
FROM main.default.diamonds
WHERE clarity = clarity_value
関数を呼び出す
このセクションでは、上記の関数を呼び出すコードについて説明します。 たとえば、これらの関数を使用して、指定した列内に指定した値が存在するテーブル内の行数をカウントできます。 ただし、先に進む前に、テーブルが実際に存在するかどうか、および列がそのテーブルに実際に存在するかどうかを確認する必要があります。 次のコードでは、これらの条件を確認します。
前のセクションの関数を Databricks ワークスペースに追加した場合は、次のようにワークスペースからこれらの関数を呼び出すことができます。
- Python
- R
- Scala
- SQL
Create a Python notebook in the same folder as the preceding myfunctions.py
file in your repo, and add the following contents to the notebook. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.
from myfunctions import *
tableName = "diamonds"
dbName = "default"
columnName = "clarity"
columnValue = "VVS2"
# If the table exists in the specified database...
if tableExists(tableName, dbName):
df = spark.sql(f"SELECT * FROM {dbName}.{tableName}")
# And the specified column exists in that table...
if columnExists(df, columnName):
# Then report the number of rows for the specified value in that column.
numRows = numRowsInColumnForValue(df, columnName, columnValue)
print(f"There are {numRows} rows in '{tableName}' where '{columnName}' equals '{columnValue}'.")
else:
print(f"Column '{columnName}' does not exist in table '{tableName}' in schema (database) '{dbName}'.")
else:
print(f"Table '{tableName}' does not exist in schema (database) '{dbName}'.")
Create an R notebook in the same folder as the preceding myfunctions.r
file in your repo, and add the following contents to the notebook. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.
library(SparkR)
source("myfunctions.r")
table_name <- "diamonds"
db_name <- "default"
column_name <- "clarity"
column_value <- "VVS2"
# If the table exists in the specified database...
if (table_exists(table_name, db_name)) {
df = sql(paste("SELECT * FROM ", db_name, ".", table_name, sep = ""))
# And the specified column exists in that table...
if (column_exists(df, column_name)) {
# Then report the number of rows for the specified value in that column.
num_rows = num_rows_in_column_for_value(df, column_name, column_value)
print(paste("There are ", num_rows, " rows in table '", table_name, "' where '", column_name, "' equals '", column_value, "'.", sep = ""))
} else {
print(paste("Column '", column_name, "' does not exist in table '", table_name, "' in schema (database) '", db_name, "'.", sep = ""))
}
} else {
print(paste("Table '", table_name, "' does not exist in schema (database) '", db_name, "'.", sep = ""))
}
Create another Scala notebook in the same folder as the preceding myfunctions
Scala notebook, and add the following contents to this new notebook.
In this new notebook’s first cell, add the following code, which calls the %run magic. This magic makes the contents of the myfunctions
notebook available to your new notebook.
%run ./myfunctions
In this new notebook’s second cell, add the following code. Change the variable values for the table name, the schema (database) name, the column name, and the column value as needed. Then attach the notebook to a cluster and run the notebook to see the results.
val tableName = "diamonds"
val dbName = "default"
val columnName = "clarity"
val columnValue = "VVS2"
// If the table exists in the specified database...
if (tableExists(tableName, dbName)) {
val df = spark.sql("SELECT * FROM " + dbName + "." + tableName)
// And the specified column exists in that table...
if (columnExists(df, columnName)) {
// Then report the number of rows for the specified value in that column.
val numRows = numRowsInColumnForValue(df, columnName, columnValue)
println("There are " + numRows + " rows in '" + tableName + "' where '" + columnName + "' equals '" + columnValue + "'.")
} else {
println("Column '" + columnName + "' does not exist in table '" + tableName + "' in database '" + dbName + "'.")
}
} else {
println("Table '" + tableName + "' does not exist in database '" + dbName + "'.")
}
Add the following code to a new cell in the preceding notebook or to a cell in a separate notebook. Change the schema or catalog names if necessary to match yours, and then run this cell to see the results.
SELECT CASE
-- If the table exists in the specified catalog and schema...
WHEN
table_exists("main", "default", "diamonds")
THEN
-- And the specified column exists in that table...
(SELECT CASE
WHEN
column_exists("main", "default", "diamonds", "clarity")
THEN
-- Then report the number of rows for the specified value in that column.
printf("There are %d rows in table 'main.default.diamonds' where 'clarity' equals 'VVS2'.",
num_rows_for_clarity_in_diamonds("VVS2"))
ELSE
printf("Column 'clarity' does not exist in table 'main.default.diamonds'.")
END)
ELSE
printf("Table 'main.default.diamonds' does not exist.")
END
単体テストの記述
このセクションでは、この記事の冒頭で説明した各関数をテストするコードについて説明します。 将来、関数に変更を加えた場合は、単体テストを使用して、それらの関数が期待どおりに動作するかどうかを判断できます。
この記事の冒頭で関数を Databricks ワークスペースに追加した場合は、次のようにして、これらの関数の単体テストをワークスペースに追加できます。
- Python
- R
- Scala
- SQL
Create another file named test_myfunctions.py
in the same folder as the preceding myfunctions.py
file in your repo, and add the following contents to the file. By default, pytest
looks for .py
files whose names start with test_
(or end with _test
) to test. Similarly, by default, pytest
looks inside of these files for functions whose names start with test_
to test.
In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.
import pytest
import pyspark
from myfunctions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType
tableName = "diamonds"
dbName = "default"
columnName = "clarity"
columnValue = "SI2"
# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
.appName('integrity-tests') \
.getOrCreate()
# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema = StructType([ \
StructField("_c0", IntegerType(), True), \
StructField("carat", FloatType(), True), \
StructField("cut", StringType(), True), \
StructField("color", StringType(), True), \
StructField("clarity", StringType(), True), \
StructField("depth", FloatType(), True), \
StructField("table", IntegerType(), True), \
StructField("price", IntegerType(), True), \
StructField("x", FloatType(), True), \
StructField("y", FloatType(), True), \
StructField("z", FloatType(), True), \
])
data = [ (1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43 ), \
(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31 ) ]
df = spark.createDataFrame(data, schema)
# Does the table exist?
def test_tableExists():
assert tableExists(tableName, dbName) is True
# Does the column exist?
def test_columnExists():
assert columnExists(df, columnName) is True
# Is there at least one row for the value in the specified column?
def test_numRowsInColumnForValue():
assert numRowsInColumnForValue(df, columnName, columnValue) > 0
Create another file named test_myfunctions.r
in the same folder as the preceding myfunctions.r
file in your repo, and add the following contents to the file. By default, testthat
looks for .r
files whose names start with test
to test.
In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.
library(testthat)
source("myfunctions.r")
table_name <- "diamonds"
db_name <- "default"
column_name <- "clarity"
column_value <- "SI2"
# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema <- structType(
structField("_c0", "integer"),
structField("carat", "float"),
structField("cut", "string"),
structField("color", "string"),
structField("clarity", "string"),
structField("depth", "float"),
structField("table", "integer"),
structField("price", "integer"),
structField("x", "float"),
structField("y", "float"),
structField("z", "float"))
data <- list(list(as.integer(1), 0.23, "Ideal", "E", "SI2", 61.5, as.integer(55), as.integer(326), 3.95, 3.98, 2.43),
list(as.integer(2), 0.21, "Premium", "E", "SI1", 59.8, as.integer(61), as.integer(326), 3.89, 3.84, 2.31))
df <- createDataFrame(data, schema)
# Does the table exist?
test_that ("The table exists.", {
expect_true(table_exists(table_name, db_name))
})
# Does the column exist?
test_that ("The column exists in the table.", {
expect_true(column_exists(df, column_name))
})
# Is there at least one row for the value in the specified column?
test_that ("There is at least one row in the query result.", {
expect_true(num_rows_in_column_for_value(df, column_name, column_value) > 0)
})
Create another Scala notebook in the same folder as the preceding myfunctions
Scala notebook, and add the following contents to this new notebook.
In the new notebook’s first cell, add the following code, which calls the %run
magic. This magic makes the contents of the myfunctions
notebook available to your new notebook.
%run ./myfunctions
In the second cell, add the following code. This code defines your unit tests and specifies how to run them.
In general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to create fake data that is as close as possible to the production data. The following code example creates fake data for the unit tests to run against.
import org.scalatest._
import org.apache.spark.sql.types.{StructType, StructField, IntegerType, FloatType, StringType}
import scala.collection.JavaConverters._
class DataTests extends AsyncFunSuite {
val tableName = "diamonds"
val dbName = "default"
val columnName = "clarity"
val columnValue = "SI2"
// Create fake data for the unit tests to run against.
// In general, it is a best practice to not run unit tests
// against functions that work with data in production.
val schema = StructType(Array(
StructField("_c0", IntegerType),
StructField("carat", FloatType),
StructField("cut", StringType),
StructField("color", StringType),
StructField("clarity", StringType),
StructField("depth", FloatType),
StructField("table", IntegerType),
StructField("price", IntegerType),
StructField("x", FloatType),
StructField("y", FloatType),
StructField("z", FloatType)
))
val data = Seq(
Row(1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43),
Row(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31)
).asJava
val df = spark.createDataFrame(data, schema)
// Does the table exist?
test("The table exists") {
assert(tableExists(tableName, dbName) == true)
}
// Does the column exist?
test("The column exists") {
assert(columnExists(df, columnName) == true)
}
// Is there at least one row for the value in the specified column?
test("There is at least one matching row") {
assert(numRowsInColumnForValue(df, columnName, columnValue) > 0)
}
}
nocolor.nodurations.nostacks.stats.run(new DataTests)
This code example uses the FunSuite
style of testing in ScalaTest. For other available testing styles, see Selecting testing styles for your project.
Before you add unit tests, you should be aware that in general, it is a best practice to not run unit tests against functions that work with data in production. This is especially important for functions that add, remove, or otherwise change data. To protect your production data from being compromised by your unit tests in unexpected ways, you should run unit tests against non-production data. One common approach is to run unit tests against views instead of tables.
To create a view, you can call the CREATE VIEW command from a new cell in either the preceding notebook or a separate notebook. The following example assumes that you have an existing table named diamonds
within a schema (database) named default
within a catalog named main
. Change these names to match your own as needed, and then run only that cell.
USE CATALOG main;
USE SCHEMA default;
CREATE VIEW view_diamonds AS
SELECT * FROM diamonds;
After you create the view, add each of the following SELECT
statements to its own new cell in the preceding notebook or to its own new cell in a separate notebook. Change the names to match your own as needed.
SELECT if(table_exists("main", "default", "view_diamonds"),
printf("PASS: The table 'main.default.view_diamonds' exists."),
printf("FAIL: The table 'main.default.view_diamonds' does not exist."));
SELECT if(column_exists("main", "default", "view_diamonds", "clarity"),
printf("PASS: The column 'clarity' exists in the table 'main.default.view_diamonds'."),
printf("FAIL: The column 'clarity' does not exists in the table 'main.default.view_diamonds'."));
SELECT if(num_rows_for_clarity_in_diamonds("VVS2") > 0,
printf("PASS: The table 'main.default.view_diamonds' has at least one row where the column 'clarity' equals 'VVS2'."),
printf("FAIL: The table 'main.default.view_diamonds' does not have at least one row where the column 'clarity' equals 'VVS2'."));
単体テストの実行
このセクションでは、前のセクションでコーディングした単体テストを実行する方法について説明します。 単体テストを実行すると、成功した単体テストと失敗した単体テストの結果が表示されます。
前のセクションの単体テストを Databricks ワークスペースに追加した場合は、ワークスペースからこれらの単体テストを実行できます。 これらの単体テストは、 手動で 実行することも、 スケジュールに従って実行することもできます。
- Python
- R
- Scala
- SQL
Create a Python notebook in the same folder as the preceding test_myfunctions.py
file in your repo, and add the following contents.
In the new notebook’s first cell, add the following code, and then run the cell, which calls the %pip
magic. This magic installs pytest
.
%pip install pytest
In the second cell, add the following code and then run the cell. Results show which unit tests passed and failed.
import pytest
import sys
# Skip writing pyc files on a readonly filesystem.
sys.dont_write_bytecode = True
# Run pytest.
retcode = pytest.main([".", "-v", "-p", "no:cacheprovider"])
# Fail the cell execution if there are any test failures.
assert retcode == 0, "The pytest invocation failed. See the log for details."
Create an R notebook in the same folder as the preceding test_myfunctions.r
file in your repo, and add the following contents.
In the first cell, add the following code, and then run the cell, which calls the install.packages
function. This function installs testthat
.
install.packages("testthat")
In the second cell, add the following code, and then run the cell. Results show which unit tests passed and failed.
library(testthat)
source("myfunctions.r")
test_dir(".", reporter = "tap")
Run the first and then second cells in the notebook from the preceding section. Results show which unit tests passed and failed.
Run each of the three cells in the notebook from the preceding section. Results show whether each unit test passed or failed.
If you no longer need the view after you run your unit tests, you can delete the view. To delete this view, you can add the following code to a new cell within one of the preceding notebooks and then run only that cell.
DROP VIEW view_diamonds;
ノートブックの実行結果 (単体テストの結果を含む) は、クラスターの ドライバー ログで表示できます。 クラスターの ログ配信の場所を指定することもできます。
などの継続的インテグレーションと継続的デリバリーまたはデプロイメントCI/CD ()GitHub Actions システムを設定して、コードが変更されるたびにユニット・テストを自動的に実行できます。例については、「 ノートブックのソフトウェア エンジニアリングのベスト プラクティス」の GitHub Actions のカバレッジを参照してください。