メインコンテンツまでスキップ

Databricks Connect for Scala のユーザー定義関数

注記

この記事では、Databricks Runtime 14.1 以降の Databricks Connect について説明します。

Databricks Connect for Scala は、ローカル開発環境から Databricks クラスター上でユーザー定義関数 (UDF) を実行することをサポートします。

このページでは、Databricks Connect for Scala を使用してユーザー定義関数を実行する方法について説明します。

この記事の Python バージョンについては、「 Databricks Connect for Python のユーザー定義関数」を参照してください。

コンパイルされたクラスとJARをアップロードする

UDF が機能するには、コンパイルされたクラスと JAR をaddCompiledArtifacts() API を使用してクラスターにアップロードする必要があります。

注記

クライアントで使用するScalaは、DatabricksクラスターのScalaバージョンと一致する必要があります。クラスターのScalaバージョンを確認するには、Databricks Runtimeリリースノートのバージョンと互換性でクラスターのDatabricksランタイム バージョンの「システム環境」セクションを参照してください。

次の Scala プログラムは、列の値を 2 乗する単純な UDF を設定します。

Scala
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
def main(args: Array[String]): Unit = {
val spark = getSession()

val squared = udf((x: Long) => x * x)

spark.range(3)
.withColumn("squared", squared(col("id")))
.select("squared")
.show()

}
}

def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
// On a Databricks cluster — reuse the active session
SparkSession.active
} else {
// Locally with Databricks Connect — upload local JARs and classes
DatabricksSession
.builder()
.addCompiledArtifacts(
Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
)
.getOrCreate()
}
}
}

Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI プロジェクトのコンパイル済み出力と同じ場所を指します (たとえば、target/classes またはビルドされた JAR)。Mainだけでなく、コンパイルされたすべてのクラスが Databricks にアップロードされます。

target/scala-2.13/classes/
├── com/
│ ├── examples/
│ │ ├── Main.class
│ │ └── MyUdfs.class
│ └── utils/
│ └── Helper.class

Spark セッションがすでに初期化されている場合は、 spark.addArtifact() APIを用いてさらにコンパイルしたクラスやJARをアップロードすることができます。

注記

JAR をアップロードするときは、すべての推移的な依存関係 JAR をアップロードに含める必要があります。 APIは、推移的な依存関係の自動検出を実行しません。

サードパーティの依存関係を持つUDF

UDF で使用されているが、Databricks クラスターでは使用できない Maven 依存関係をbuild.sbtに追加した場合、次のようになります。

// In build.sbt
libraryDependencies += "org.apache.commons" % "commons-text" % "1.10.0"
Scala
// In your code
import org.apache.commons.text.StringEscapeUtils

// ClassNotFoundException thrown during UDF execution of this function on the server side
val escapeUdf = udf((text: String) => {
StringEscapeUtils.escapeHtml4(text)
})

Maven から依存関係をダウンロードするには、 spark.addArtifact()ivy://を使用します。

  1. oroライブラリをbuild.sbtファイルに追加します

    libraryDependencies ++= Seq(
    "org.apache.commons" % "commons-text" % "1.10.0" % Provided,
    "oro" % "oro" % "2.0.8" // Required for ivy:// to work
    )
  2. addArtifact() API を使用してセッションを作成した後、アーティファクトを追加します。

    Scala
    def getSession(): SparkSession = {
    if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
    SparkSession.active
    } else {
    val spark = DatabricksSession.builder()
    .addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI)
    .getOrCreate()

    // Convert Maven coordinates to ivy:// format
    // From: "org.apache.commons" % "commons-text" % "1.10.0"
    // To: ivy://org.apache.commons:commons-text:1.10.0
    spark.addArtifact("ivy://org.apache.commons:commons-text:1.10.0")

    spark
    }
    }

型付き データセット API

型付きデータセットAPIs使用すると、結果のデータセットに対してmap()filter()mapPartitions()などの変換や集計を実行できます。 addCompiledArtifacts() API を使用してコンパイルされたクラスと JAR をクラスターにアップロードすると、これらにも適用されるため、コードは実行場所に応じて異なる動作をする必要があります。

  • Databricks Connectを使用した ローカル開発 : アーティファクトをリモート クラスターにアップロードします。
  • クラスター上で実行されている Databricks にデプロイされています 。クラスがすでに存在するため、何もアップロードする必要はありません。

次の Scala アプリケーションは、 map() API を使用して、result カラムの数値をプレフィックス付きの文字列に変更します。

Scala
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()

spark.range(3).map(f => s"row-$f").show()
}
}

外部JAR依存関係

クラスター上にないプライベート ライブラリまたはサードパーティ ライブラリを使用している場合:

Scala
import com.mycompany.privatelib.DataProcessor

// ClassNotFoundException thrown during UDF execution of this function on the server side
val myUdf = udf((data: String) => {
DataProcessor.process(data)
})

セッションを作成するときに、 lib/フォルダーから外部 JAR をアップロードします。

Scala
def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
SparkSession.active
} else {
val builder = DatabricksSession.builder()
.addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI)

// Add all JARs from lib/ folder
val libFolder = new java.io.File("lib")
builder.addCompiledArtifacts(libFolder.toURI)


builder.getOrCreate()
}
}

これにより、ローカルで実行しているときに、lib/ ディレクトリ内のすべての JAR が Databricks に自動的にアップロードされます。

複数のモジュールを持つプロジェクト

マルチモジュール SBT プロジェクトでは、 getClass.getProtectionDomain.getCodeSource.getLocation.toURI現在のモジュールの場所のみを返します。UDF が他のモジュールのクラスを使用する場合は、 ClassNotFoundExceptionが返されます。

my-project/
├── module-a/ (main application)
├── module-b/ (utilities - module-a depends on this)

各モジュールのクラスのgetClassを使用して、すべての場所を取得し、個別にアップロードします。

Scala
// In module-a/src/main/scala/Main.scala
import com.company.moduleb.DataProcessor // From module-b

def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
SparkSession.active
} else {
// Get location using a class FROM module-a
val moduleALocation = Main.getClass
.getProtectionDomain.getCodeSource.getLocation.toURI

// Get location using a class FROM module-b
val moduleBLocation = DataProcessor.getClass
.getProtectionDomain.getCodeSource.getLocation.toURI

DatabricksSession.builder()
.addCompiledArtifacts(moduleALocation) // Upload module-a
.addCompiledArtifacts(moduleBLocation) // Upload module-b
.getOrCreate()
}
}