500件以上のラベルでドキュメントを分類します
ai_classify 1回の呼び出しで最大500ラベルを受け付けます。より大規模なタクソノミーの場合は、文書ごとに埋め込み類似度を使用してラベルを事前にフィルタリングし、次に、トップK候補のショートリストに対してai_classifyを呼び出します。このチュートリアルでは、最適なK (精度を維持する最小の候補数) を見つける方法を説明します。
vector_cosine_similarity Databricks Runtime 18.1 以降でのみ利用可能です。
始める前に
- Unity Catalog が有効なワークスペースは、
ai_classifyにアクセスできます (提供状況をご覧ください)。 - Databricks Runtime 18.1以上(
vector_cosine_similarityに必要です)。 - 分類対象ドキュメントのDeltaテーブルです。
- キーカラムとオプションの説明カラムを備えた、ラベルの Delta テーブル。
- グラウンドトゥルースラベル付きの小規模な評価セット、または作成する機能 (ステップ4のオプションB)。
0. 構成。
テーブル名、列名、および埋め込みモデルを設定してください。ヘルパー関数「top_k_labels_json」は、各ドキュメントのトップK候補に対して、「ai_classify」に渡されるJSON式を構築します。
# -- Your tables --
DOCS_TABLE = "path.to.your_docs_table" # table of documents to classify
DOCS_TEXT_COL = "document" # column with text to classify
DOCS_ID_COL = None # unique ID column; set to None to auto-generate via md5
LABELS_TABLE = "path.to.your_labels_table" # table of labels
LABELS_KEY_COL = "label" # column with label value
LABELS_DESC_COL = "description" # description column; set to None if labels have no descriptions
# -- Embedding model --
EMBEDDING_MODEL = "databricks-qwen3-embedding-0-6b" # compact model, good default for English text
# -- K values to sweep --
K_VALUES = [10, 20, 50, 100, 200, 500]
# -- Eval set size (if you need to create one) --
EVAL_SAMPLE_SIZE = 100 # docs to sample for manual labeling
doc_id_expr = DOCS_ID_COL if DOCS_ID_COL else f"md5({DOCS_TEXT_COL})"
label_embed_text = (
f"concat({LABELS_KEY_COL}, ': ', {LABELS_DESC_COL})"
if LABELS_DESC_COL
else LABELS_KEY_COL
)
def top_k_labels_json(k_val, prefix=""):
"""Build a JSON expression for top-K labels using max_by aggregate."""
col_prefix = f"{prefix}." if prefix else ""
if LABELS_DESC_COL:
return f"to_json(map_from_entries(max_by(struct({col_prefix}{LABELS_KEY_COL}, {col_prefix}{LABELS_DESC_COL}), {col_prefix}score, {k_val})))"
else:
return f"to_json(max_by({col_prefix}{LABELS_KEY_COL}, {col_prefix}score, {k_val}))"
print(f"Docs table: {DOCS_TABLE} (text: {DOCS_TEXT_COL}, id: {doc_id_expr})")
print(f"Labels table: {LABELS_TABLE} (key: {LABELS_KEY_COL}, desc: {LABELS_DESC_COL})")
print(f"Embed text: {label_embed_text}")
print(f"K sweep: {K_VALUES}")
1. ラベルを埋め込む
一度実行してください。分類が変更されたときのみ再実行してください。
spark.sql(f"""
CREATE OR REPLACE TABLE label_embeddings AS
SELECT
{LABELS_KEY_COL},
{f'{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
cast(
ai_query('{EMBEDDING_MODEL}', {label_embed_text}) AS ARRAY<FLOAT>
) AS embedding
FROM {LABELS_TABLE}
""")
label_count = spark.sql("SELECT count(*) AS n FROM label_embeddings").first()["n"]
print(f"Embedded {label_count} labels")
2. ドキュメントを埋め込む
spark.sql(f"""
CREATE OR REPLACE TABLE doc_embeddings AS
SELECT
{doc_id_expr} AS id,
{DOCS_TEXT_COL},
cast(
ai_query('{EMBEDDING_MODEL}', {DOCS_TEXT_COL}) AS ARRAY<FLOAT>
) AS embedding
FROM {DOCS_TABLE}
""")
doc_count = spark.sql("SELECT count(*) AS n FROM doc_embeddings").first()["n"]
print(f"Embedded {doc_count} documents")
3. コサイン類似度によるラベルのスコア付け
ラベル付きドキュメントをクロス結合し、すべてのペアに対してコサイン類似度をコンピュートします。これにより N_docs × N_labels 行が生成されます。非常に大規模なドキュメントテーブルについては、doc_embeddingsをパーティション分割し、各パーティションでこれを実行することを検討してください。ドキュメントあたりのTop-Kラベルは後ほどmax_by(struct, score, K)を用いて集約されます。
spark.sql(f"""
CREATE OR REPLACE TABLE scored_labels AS
SELECT
d.id,
l.{LABELS_KEY_COL},
{f'l.{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
vector_cosine_similarity(d.embedding, l.embedding) AS score
FROM doc_embeddings d
CROSS JOIN label_embeddings l
""")
row_count = spark.sql("SELECT count(*) AS n FROM scored_labels").first()["n"]
print(f"Scored labels: {row_count:,} rows ({doc_count} docs × {label_count} labels)")
4. グラウンドトゥルース評価セットを準備する
Kチューニングは、正解ラベルが既知の少数のドキュメントセットを必要とします。既存の評価テーブルがある場合、次のセルでEVAL_TABLEを設定し、サンプリングセルをスキップします。そうでない場合、2番目のセルには、手動でラベル付けして再インポートできるドキュメントのサンプルが表示されます。
# Option A: point to your existing eval table
# Must have columns: id (matching doc_embeddings.id) and ground_truth_label
EVAL_TABLE = dbutils.widgets.get("eval_table") # read from notebook widget
if EVAL_TABLE:
eval_df = spark.table(EVAL_TABLE)
print(f"Loaded {eval_df.count()} eval examples from {EVAL_TABLE}")
else:
print("No eval table set — run the next cell to sample documents for labeling.")
# Option B: sample documents for manual labeling
if not EVAL_TABLE:
sample_df = spark.sql(f"""
SELECT id, {DOCS_TEXT_COL}
FROM doc_embeddings
ORDER BY rand()
LIMIT {EVAL_SAMPLE_SIZE}
""")
sample_df.display()
print(f"\nSampled {EVAL_SAMPLE_SIZE} documents.")
print("Next steps:")
print(" 1. Export these rows (copy the table above or save to CSV)")
print(" 2. Add a 'ground_truth_label' column and fill in the correct label for each doc")
print(" 3. Re-import as a Delta table and set EVAL_TABLE above")
print(" 4. Re-run cell 4 (Option A) to load it")
5番目。再現率@K の測定
再現率@Kは、グラウンドトゥルースラベルが上位K個の埋め込み候補に含まれているかどうかを確認します。これは 検索のみのメトリクス であり、ai_classify を呼び出さずにすぐに実行されます。
特定のKにおいて再現率が低い場合、分類が実行される前から正しいラベルが候補セットから除外されていたため、ai_classify が正しい回答を返すことはできません。
assert EVAL_TABLE, "Set EVAL_TABLE in cell 4 before running K-tuning."
spark.sql(f"CREATE OR REPLACE TEMP VIEW eval_set AS SELECT * FROM {EVAL_TABLE}")
recall_results = []
for k in K_VALUES:
row = spark.sql(f"""
SELECT
{k} AS k,
count(*) AS eval_size,
sum(CASE WHEN hit THEN 1 ELSE 0 END) AS hits,
round(sum(CASE WHEN hit THEN 1 ELSE 0 END) / count(*), 4) AS recall_at_k
FROM (
SELECT
e.id,
array_contains(
max_by(r.{LABELS_KEY_COL}, r.score, {k}),
e.ground_truth_label
) AS hit
FROM eval_set e
JOIN scored_labels r ON r.id = e.id
GROUP BY e.id, e.ground_truth_label
)
""").first()
recall_results.append(row.asDict())
print(f" K={k:>4d} → Recall@K = {row['recall_at_k']:.2%} ({row['hits']}/{row['eval_size']})")
recall_df = spark.createDataFrame(recall_results)
recall_df.display()
6. エンドツーエンドの精度を測定する
各Kについて、評価ドキュメントごとに上位K個のラベルセットを構築し、ai_classifyを実行し、グラウンドトゥルースと比較します。
このステップはai_classifyを呼び出し、再現率チェックよりも費用がかかります。再現率がすでに妥当なK値から始めてください。
accuracy_results = []
for k in K_VALUES:
# Build top-K labels per eval doc using max_by aggregate
spark.sql(f"""
CREATE OR REPLACE TEMP VIEW eval_top_labels AS
SELECT
r.id,
{top_k_labels_json(k, 'r')} AS labels
FROM scored_labels r
JOIN eval_set e ON e.id = r.id
GROUP BY r.id
""")
# Materialize ai_classify first (returns VARIANT, and is non-deterministic so can't go inside aggregate)
spark.sql(f"""
CREATE OR REPLACE TEMP VIEW eval_predictions AS
SELECT
e.id,
e.ground_truth_label,
get_json_object(cast(ai_classify(d.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) as string), '$.response[0]') AS predicted_label
FROM eval_set e
JOIN doc_embeddings d ON d.id = e.id
JOIN eval_top_labels t ON t.id = e.id
""")
row = spark.sql(f"""
SELECT
{k} AS k,
count(*) AS eval_size,
sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) AS correct,
round(
sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) / count(*),
4
) AS accuracy
FROM eval_predictions
""").first()
accuracy_results.append(row.asDict())
print(f" K={k:>4d} → Accuracy = {row['accuracy']:.2%} ({row['correct']}/{row['eval_size']})")
accuracy_df = spark.createDataFrame(accuracy_results)
accuracy_df.display()
7. 結果を比較し、K を選択します。
以下のグラフは、Recall@Kとエンドツーエンドの精度を並べて示したものです。精度が向上しなくなる**最小のK**を選択してください。Kが大きくなると、分類速度が低下し、品質の向上は見込めません。
import pandas as pd
import matplotlib.pyplot as plt
recall_pd = pd.DataFrame(recall_results)
accuracy_pd = pd.DataFrame(accuracy_results)
combined = recall_pd.merge(accuracy_pd, on="k", suffixes=("_recall", "_acc"))
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(combined["k"], combined["recall_at_k"], "o-", label="Recall@K", linewidth=2)
ax.plot(combined["k"], combined["accuracy"], "s--", label="End-to-end accuracy", linewidth=2)
ax.set_xlabel("K (candidate labels per document)")
ax.set_ylabel("Score")
ax.set_title("K-Tuning: Recall@K vs End-to-End Accuracy")
ax.set_ylim(0, 1.05)
ax.set_xticks(combined["k"])
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nFull results:")
print(combined[["k", "recall_at_k", "accuracy"]].to_string(index=False))
# Pick your K based on the chart above
CHOSEN_K = 50 # <-- edit this
chosen_row = combined[combined["k"] == CHOSEN_K].iloc[0]
print(f"Chosen K = {CHOSEN_K}")
print(f" Recall@K: {chosen_row['recall_at_k']:.2%}")
print(f" End-to-end accuracy: {chosen_row['accuracy']:.2%}")
8. 選択したKで完全な分類を実行します。
選択したKをドキュメントテーブル全体に適用してください。
spark.sql(f"""
CREATE OR REPLACE TABLE top_labels_per_doc AS
SELECT
id,
{top_k_labels_json(CHOSEN_K, 'r')} AS labels
FROM scored_labels r
GROUP BY id
""")
print(f"Built top-{CHOSEN_K} label sets for all documents")
result_df = spark.sql(f"""
SELECT
c.{DOCS_TEXT_COL},
cast(ai_classify(c.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) as string) AS classification
FROM {DOCS_TABLE} c
JOIN top_labels_per_doc t ON t.id = {doc_id_expr.replace(DOCS_TEXT_COL, f'c.{DOCS_TEXT_COL}')}
""")
result_df.display()
# Optionally save results
# result_df.write.mode("overwrite").saveAsTable("my_catalog.my_schema.classification_results")