
Classify documents with 500+ labels

ai_classify accepts up to 500 labels per call. For larger taxonomies, pre-filter the labels for each document by embedding similarity, then call ai_classify on a shortlist of the top-K candidates. This tutorial shows how to find the optimal K: the smallest number of candidates that preserves accuracy.
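The pre-filtering idea can be sketched in plain Python, outside Spark. This is a toy illustration only. The three-dimensional vectors and label names are made up, and cosine similarity and top-K selection are written by hand rather than with vector_cosine_similarity and max_by.

```python
import heapq
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def shortlist_labels(doc_vec, label_vecs, k):
    """Return the k label names whose embeddings are most similar to doc_vec."""
    scored = ((cosine_similarity(doc_vec, vec), name) for name, vec in label_vecs.items())
    return [name for _, name in heapq.nlargest(k, scored)]


# Hypothetical toy embeddings; real ones come from the embedding model
labels = {
    "billing": [0.9, 0.1, 0.0],
    "shipping": [0.1, 0.9, 0.0],
    "returns": [0.2, 0.7, 0.3],
    "account access": [0.0, 0.1, 0.9],
}
doc = [0.1, 0.8, 0.2]

# The shortlist is what would be passed to ai_classify instead of the full taxonomy
print(shortlist_labels(doc, labels, k=2))  # → ['returns', 'shipping']
```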

note

vector_cosine_similarity is only available on Databricks Runtime 18.1 and above.

Before you begin

  • A Unity Catalog-enabled workspace with access to ai_classify (see availability).
  • Databricks Runtime 18.1+ (required for vector_cosine_similarity).
  • A Delta table of documents to classify.
  • A Delta table of labels with a key column and an optional description column.
  • A small evaluation set with ground-truth labels, or the ability to create one (Option B in step 4).

0. Configuration

Set your table names, column names, and embedding model. The helper function top_k_labels_json builds the SQL expression that produces each document's top-K candidate labels as JSON. That JSON is what gets passed to ai_classify.

Python
# -- Your tables --
DOCS_TABLE = "path.to.your_docs_table" # table of documents to classify
DOCS_TEXT_COL = "document" # column with text to classify
DOCS_ID_COL = None # unique ID column; set to None to auto-generate via md5

LABELS_TABLE = "path.to.your_labels_table" # table of labels
LABELS_KEY_COL = "label" # column with label value
LABELS_DESC_COL = "description" # description column; set to None if labels have no descriptions

# -- Embedding model --
EMBEDDING_MODEL = "databricks-qwen3-embedding-0-6b" # compact model, good default for English text

# -- K values to sweep --
K_VALUES = [10, 20, 50, 100, 200, 500]

# -- Eval set size (if you need to create one) --
EVAL_SAMPLE_SIZE = 100 # docs to sample for manual labeling
Python
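# SQL expression for each document's ID (md5 of the text when there is no ID column)
doc_id_expr = DOCS_ID_COL if DOCS_ID_COL else f"md5({DOCS_TEXT_COL})"

# SQL expression for the text embedded per label: "label: description" when descriptions exist
label_embed_text = (
    f"concat({LABELS_KEY_COL}, ': ', {LABELS_DESC_COL})"
    if LABELS_DESC_COL
    else LABELS_KEY_COL
)


def top_k_labels_json(k_val, prefix=""):
    """Build a JSON expression for top-K labels using max_by aggregate."""
    col_prefix = f"{prefix}." if prefix else ""
    if LABELS_DESC_COL:
        # JSON object {label: description} built from the K highest-scoring labels
        return f"to_json(map_from_entries(max_by(struct({col_prefix}{LABELS_KEY_COL}, {col_prefix}{LABELS_DESC_COL}), {col_prefix}score, {k_val})))"
    else:
        # JSON array of the K highest-scoring label keys
        return f"to_json(max_by({col_prefix}{LABELS_KEY_COL}, {col_prefix}score, {k_val}))"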

print(f"Docs table: {DOCS_TABLE} (text: {DOCS_TEXT_COL}, id: {doc_id_expr})")
print(f"Labels table: {LABELS_TABLE} (key: {LABELS_KEY_COL}, desc: {LABELS_DESC_COL})")
print(f"Embed text: {label_embed_text}")
print(f"K sweep: {K_VALUES}")
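You can preview the helper's output without Spark by running a parameterized copy of it. The variant below, top_k_labels_json_for, is hypothetical. It takes the column names as arguments instead of reading the config globals, and returns the same SQL strings that top_k_labels_json builds. With descriptions, each document's candidates become a JSON object mapping label to description. Without them, the candidates are a JSON array of label names.

```python
def top_k_labels_json_for(k_val, key_col, desc_col=None, prefix=""):
    """Hypothetical standalone copy of top_k_labels_json with explicit column names."""
    col_prefix = f"{prefix}." if prefix else ""
    if desc_col:
        # One {label: description} JSON object per document
        return (
            f"to_json(map_from_entries(max_by(struct({col_prefix}{key_col}, "
            f"{col_prefix}{desc_col}), {col_prefix}score, {k_val})))"
        )
    # One JSON array of label keys per document
    return f"to_json(max_by({col_prefix}{key_col}, {col_prefix}score, {k_val}))"


print(top_k_labels_json_for(50, "label", "description", prefix="r"))
print(top_k_labels_json_for(50, "label", prefix="r"))
```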

1. Embed the labels

Run this once. Re-run only when the taxonomy changes.

Python
spark.sql(f"""
CREATE OR REPLACE TABLE label_embeddings AS
SELECT
{LABELS_KEY_COL},
{f'{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
cast(
ai_query('{EMBEDDING_MODEL}', {label_embed_text}) AS ARRAY<FLOAT>
) AS embedding
FROM {LABELS_TABLE}
""")

label_count = spark.sql("SELECT count(*) AS n FROM label_embeddings").first()["n"]
print(f"Embedded {label_count} labels")
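When descriptions are available, each label is embedded as the string "label: description". In Spark SQL, concat returns NULL if any argument is NULL. A label with a missing description therefore gets a NULL embedding input. One possible adjustment, not part of the original tutorial, is concat_ws(': ', label, description), which skips NULL arguments. The plain-Python sketch below mirrors that behavior. The function name embed_text_for_label is hypothetical.

```python
def embed_text_for_label(label, description=None):
    """Text to embed for one label, mirroring SQL concat_ws(': ', label, description).

    concat_ws skips NULL (None) arguments, so a missing description yields just the label.
    Empty strings are not skipped, matching concat_ws.
    """
    parts = [p for p in (label, description) if p is not None]
    return ": ".join(parts)


print(embed_text_for_label("shipping", "Questions about delivery times"))
print(embed_text_for_label("shipping"))
```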

2. Embed the documents

Python
spark.sql(f"""
CREATE OR REPLACE TABLE doc_embeddings AS
SELECT
{doc_id_expr} AS id,
{DOCS_TEXT_COL},
cast(
ai_query('{EMBEDDING_MODEL}', {DOCS_TEXT_COL}) AS ARRAY<FLOAT>
) AS embedding
FROM {DOCS_TABLE}
""")

doc_count = spark.sql("SELECT count(*) AS n FROM doc_embeddings").first()["n"]
print(f"Embedded {doc_count} documents")
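When DOCS_ID_COL is None, each document's id is md5 of its text. Documents with identical text therefore share an id. For those ids, scored_labels gets one row per label per duplicate, so labels can repeat within the top-K set. With descriptions, map_from_entries may also reject the duplicate keys. The sketch below checks for duplicate ids in plain Python before you rely on the md5 fallback. It uses hashlib, whose lowercase hex digest matches the value Spark's md5() returns for a UTF-8 string. The sample documents are hypothetical.

```python
import hashlib
from collections import Counter


def md5_id(text):
    """Lowercase hex MD5 of the UTF-8 encoded text."""
    return hashlib.md5(text.encode("utf-8")).hexdigest()


# Hypothetical sample documents; two of them have identical text
docs = ["Where is my parcel?", "I want a refund.", "Where is my parcel?"]

ids = [md5_id(d) for d in docs]
duplicate_ids = [doc_id for doc_id, n in Counter(ids).items() if n > 1]
print(f"{len(docs)} docs, {len(set(ids))} distinct ids, {len(duplicate_ids)} duplicated id(s)")
```

If duplicates show up, deduplicate the texts first or set DOCS_ID_COL to a real unique key.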

3. Score labels by cosine similarity

Cross-join documents with labels and compute cosine similarity for every pair. This produces N_docs × N_labels rows. For very large document tables, consider partitioning doc_embeddings and running this over each partition. Top-K labels per document are aggregated later using max_by(struct, score, K).

Python
spark.sql(f"""
CREATE OR REPLACE TABLE scored_labels AS
SELECT
d.id,
l.{LABELS_KEY_COL},
{f'l.{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
vector_cosine_similarity(d.embedding, l.embedding) AS score
FROM doc_embeddings d
CROSS JOIN label_embeddings l
""")

row_count = spark.sql("SELECT count(*) AS n FROM scored_labels").first()["n"]
print(f"Scored labels: {row_count:,} rows ({doc_count} docs × {label_count} labels)")
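The size of scored_labels is the product of the two table sizes, so it is worth estimating before you run the cross join. The back-of-the-envelope sketch below uses hypothetical document and label counts. Splitting documents into equal-sized batches is one way to follow the partitioning suggestion above, not a prescribed method, and the batch count of 20 is arbitrary.

```python
import math


def cross_join_rows(n_docs, n_labels):
    """Rows produced by doc_embeddings CROSS JOIN label_embeddings."""
    return n_docs * n_labels


def rows_per_batch(n_docs, n_labels, n_batches):
    """Rows in the largest batch when documents are split into n_batches groups."""
    docs_per_batch = math.ceil(n_docs / n_batches)
    return docs_per_batch * n_labels


n_docs, n_labels = 1_000_000, 2_000  # hypothetical table sizes
print(f"Total rows: {cross_join_rows(n_docs, n_labels):,}")
print(f"Rows per batch (20 batches): {rows_per_batch(n_docs, n_labels, 20):,}")
```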

4. Prepare a ground-truth evaluation set

K-tuning requires a small set of documents with known correct labels. If you already have an eval table, enter its name in the eval_table notebook widget (Option A) and skip the sampling cell. If not, the Option B cell samples documents that you can label manually and re-import.

Python
# Option A: point to your existing eval table
# Must have columns: id (matching doc_embeddings.id) and ground_truth_label
dbutils.widgets.text("eval_table", "", "Eval table")  # create the widget so get() doesn't fail
EVAL_TABLE = dbutils.widgets.get("eval_table")  # read from notebook widget; empty if unset

if EVAL_TABLE:
    eval_df = spark.table(EVAL_TABLE)
    print(f"Loaded {eval_df.count()} eval examples from {EVAL_TABLE}")
else:
    print("No eval table set. Run the next cell to sample documents for labeling.")
Python
# Option B: sample documents for manual labeling
if not EVAL_TABLE:
    sample_df = spark.sql(f"""
        SELECT id, {DOCS_TEXT_COL}
        FROM doc_embeddings
        ORDER BY rand()
        LIMIT {EVAL_SAMPLE_SIZE}
    """)
    sample_df.display()
    print(f"\nSampled up to {EVAL_SAMPLE_SIZE} documents.")
    print("Next steps:")
    print("  1. Export these rows (copy the table above or save to CSV)")
    print("  2. Add a 'ground_truth_label' column and fill in the correct label for each doc")
    print("  3. Re-import as a Delta table and enter its name in the eval_table widget")
    print("  4. Re-run cell 4 (Option A) to load it")
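Steps 5 and 6 compare labels by exact string equality (array_contains and predicted_label = ground_truth_label). A ground-truth value that differs from the taxonomy only in case or whitespace therefore counts as a miss. The plain-Python sketch below pre-checks hand-labeled values. It assumes you have the label keys and the filled-in ground-truth values as lists. The helper name and sample values are hypothetical.

```python
def find_unknown_labels(ground_truth, taxonomy):
    """Map each ground-truth value that is not an exact label key to a suggested key (or None)."""
    keys = set(taxonomy)
    # Lookup for near-misses that differ only in case or surrounding whitespace
    normalized = {k.strip().lower(): k for k in taxonomy}
    problems = {}
    for value in ground_truth:
        if value not in keys:
            problems[value] = normalized.get(value.strip().lower())
    return problems


taxonomy = ["billing", "shipping", "returns"]  # hypothetical label keys
ground_truth = ["shipping", "Returns ", "warranty"]  # hypothetical hand-labeled values

for value, suggestion in find_unknown_labels(ground_truth, taxonomy).items():
    print(f"{value!r} is not a label key; suggested fix: {suggestion!r}")
```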

5. Measure Recall@K

Recall@K is the fraction of eval documents whose ground-truth label appears among their top-K embedding candidates. It is a retrieval-only metric: it does not call ai_classify, and it runs quickly because it only reads the precomputed scored_labels table.

Recall@K is an upper bound on end-to-end accuracy at the same K. If a document's correct label falls outside its top-K candidates, ai_classify cannot return the right answer for that document, because the label was excluded before classification ran.

Python
assert EVAL_TABLE, "Set EVAL_TABLE in cell 4 before running K-tuning."

spark.sql(f"CREATE OR REPLACE TEMP VIEW eval_set AS SELECT * FROM {EVAL_TABLE}")

recall_results = []
for k in K_VALUES:
    row = spark.sql(f"""
        SELECT
            {k} AS k,
            count(*) AS eval_size,
            sum(CASE WHEN hit THEN 1 ELSE 0 END) AS hits,
            round(sum(CASE WHEN hit THEN 1 ELSE 0 END) / count(*), 4) AS recall_at_k
        FROM (
            SELECT
                e.id,
                array_contains(
                    max_by(r.{LABELS_KEY_COL}, r.score, {k}),
                    e.ground_truth_label
                ) AS hit
            FROM eval_set e
            JOIN scored_labels r ON r.id = e.id
            GROUP BY e.id, e.ground_truth_label
        )
    """).first()
    recall_results.append(row.asDict())
    print(f"  K={k:>4d} → Recall@K = {row['recall_at_k']:.2%} ({row['hits']}/{row['eval_size']})")

recall_df = spark.createDataFrame(recall_results)
recall_df.display()
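The recall query above reduces to a simple computation. The plain-Python sketch below reproduces the same metric. It uses hypothetical candidate lists that stand in for scored_labels, already sorted by descending score, along with a hypothetical ground truth.

```python
def recall_at_k(ranked_candidates, ground_truth, k):
    """Fraction of docs whose ground-truth label is among their top-k candidates.

    ranked_candidates: doc id -> labels sorted by descending similarity score
    ground_truth: doc id -> correct label
    """
    hits = sum(
        1 for doc_id, label in ground_truth.items()
        if label in ranked_candidates[doc_id][:k]
    )
    return hits / len(ground_truth)


ranked = {
    "d1": ["shipping", "returns", "billing"],
    "d2": ["billing", "account access", "returns"],
    "d3": ["returns", "billing", "shipping"],
}
truth = {"d1": "shipping", "d2": "returns", "d3": "billing"}

for k in (1, 2, 3):
    print(f"K={k}: Recall@K = {recall_at_k(ranked, truth, k):.2%}")
```

With a fixed ranking, the top-(K+1) list contains the top-K list. Recall@K can therefore only stay flat or rise as K grows. The question is how quickly it levels off.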

6. Measure end-to-end accuracy

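For each K, build the top-K label set for each eval document, run ai_classify on it, and compare the prediction against the ground truth.

This step calls ai_classify once per eval document for every K, so it costs far more than the recall check. To limit cost, trim K_VALUES to the values where recall is already reasonable before running this cell.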

Python
accuracy_results = []

for k in K_VALUES:
    # Build top-K labels per eval doc using max_by aggregate
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW eval_top_labels AS
        SELECT
            r.id,
            {top_k_labels_json(k, 'r')} AS labels
        FROM scored_labels r
        JOIN eval_set e ON e.id = r.id
        GROUP BY r.id
    """)

    # Run ai_classify in its own non-aggregate view: it returns VARIANT and is
    # non-deterministic, so it can't go inside an aggregate. Temp views are not
    # materialized; ai_classify executes when the accuracy query below reads this view.
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW eval_predictions AS
        SELECT
            e.id,
            e.ground_truth_label,
            get_json_object(
                cast(ai_classify(d.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) AS string),
                '$.response[0]'
            ) AS predicted_label
        FROM eval_set e
        JOIN doc_embeddings d ON d.id = e.id
        JOIN eval_top_labels t ON t.id = e.id
    """)

    row = spark.sql(f"""
        SELECT
            {k} AS k,
            count(*) AS eval_size,
            sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) AS correct,
            round(
                sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) / count(*),
                4
            ) AS accuracy
        FROM eval_predictions
    """).first()

    accuracy_results.append(row.asDict())
    print(f"  K={k:>4d} → Accuracy = {row['accuracy']:.2%} ({row['correct']}/{row['eval_size']})")

accuracy_df = spark.createDataFrame(accuracy_results)
accuracy_df.display()
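End-to-end accuracy is computed like recall, except that it compares ai_classify's single prediction rather than the whole candidate list. The plain-Python sketch below uses hypothetical candidates and predictions. It shows why Recall@K bounds accuracy, assuming ai_classify only ever returns one of the candidate labels it was given: every correct prediction is also a recall hit. It also counts the ai_classify calls made by a full sweep with the default K_VALUES and EVAL_SAMPLE_SIZE.

```python
def exact_match_accuracy(predictions, ground_truth):
    """Fraction of docs whose predicted label equals the ground-truth label."""
    correct = sum(1 for doc_id, label in ground_truth.items() if predictions.get(doc_id) == label)
    return correct / len(ground_truth)


# Hypothetical top-2 candidates and predictions for three eval docs
candidates = {"d1": ["shipping", "returns"], "d2": ["billing", "account access"], "d3": ["returns", "billing"]}
predictions = {"d1": "shipping", "d2": "billing", "d3": "returns"}
truth = {"d1": "shipping", "d2": "returns", "d3": "billing"}

acc = exact_match_accuracy(predictions, truth)
recall = sum(1 for d, t in truth.items() if t in candidates[d]) / len(truth)
# Each prediction is drawn from its candidates, so a correct prediction implies a recall hit
assert all(predictions[d] in candidates[d] for d in truth)
print(f"Recall@2 = {recall:.2%}, accuracy = {acc:.2%}")

# Cost of the sweep: one ai_classify call per eval doc per K value
k_values = [10, 20, 50, 100, 200, 500]
eval_size = 100
print(f"ai_classify calls for the sweep: {len(k_values) * eval_size}")
```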

7. Compare results and pick K

The chart below plots Recall@K and end-to-end accuracy side by side. Pick the smallest K after which accuracy stops improving: a larger K only makes classification slower, with no gain in quality.

Python
import pandas as pd
import matplotlib.pyplot as plt

recall_pd = pd.DataFrame(recall_results)
accuracy_pd = pd.DataFrame(accuracy_results)
combined = recall_pd.merge(accuracy_pd, on="k", suffixes=("_recall", "_acc"))

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(combined["k"], combined["recall_at_k"], "o-", label="Recall@K", linewidth=2)
ax.plot(combined["k"], combined["accuracy"], "s--", label="End-to-end accuracy", linewidth=2)
ax.set_xlabel("K (candidate labels per document)")
ax.set_ylabel("Score")
ax.set_title("K-Tuning: Recall@K vs End-to-End Accuracy")
ax.set_ylim(0, 1.05)
ax.set_xticks(combined["k"])
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nFull results:")
print(combined[["k", "recall_at_k", "accuracy"]].to_string(index=False))
Python
# Pick your K based on the chart above
CHOSEN_K = 50 # <-- edit this

chosen_row = combined[combined["k"] == CHOSEN_K].iloc[0]
print(f"Chosen K = {CHOSEN_K}")
print(f" Recall@K: {chosen_row['recall_at_k']:.2%}")
print(f" End-to-end accuracy: {chosen_row['accuracy']:.2%}")
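You can also encode "the smallest K where accuracy stops improving" as an explicit rule instead of reading the chart by eye. The helper below is hypothetical and not part of the tutorial. It picks the smallest K whose accuracy is within a tolerance of the best accuracy in the sweep. The tolerance is a judgment call. With 100 eval documents, each document is worth one percentage point, so tolerance=0.02 means "within two documents of the best run".

```python
def smallest_good_k(results, tolerance=0.02):
    """Smallest K whose accuracy is within `tolerance` of the best accuracy observed.

    results: list of dicts with "k" and "accuracy" keys, shaped like accuracy_results.
    """
    best = max(r["accuracy"] for r in results)
    good = [r["k"] for r in results if r["accuracy"] >= best - tolerance]
    return min(good)


# Hypothetical sweep results (not real measurements)
sweep = [
    {"k": 10, "accuracy": 0.71},
    {"k": 20, "accuracy": 0.80},
    {"k": 50, "accuracy": 0.86},
    {"k": 100, "accuracy": 0.87},
    {"k": 200, "accuracy": 0.86},
    {"k": 500, "accuracy": 0.85},
]
print(smallest_good_k(sweep))  # → 50
```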

8. Run full classification with chosen K

Apply the selected K to your entire document table.

Python
spark.sql(f"""
CREATE OR REPLACE TABLE top_labels_per_doc AS
SELECT
id,
{top_k_labels_json(CHOSEN_K, 'r')} AS labels
FROM scored_labels r
GROUP BY id
""")

print(f"Built top-{CHOSEN_K} label sets for all documents")
Python
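# Join key: the same ID expression used for doc_embeddings, qualified with the table alias
# (avoids an ambiguous `t.id = id` when DOCS_ID_COL is "id")
join_key = f"c.{DOCS_ID_COL}" if DOCS_ID_COL else f"md5(c.{DOCS_TEXT_COL})"

result_df = spark.sql(f"""
    SELECT
        t.id,
        c.{DOCS_TEXT_COL},
        cast(ai_classify(c.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) AS string) AS classification
    FROM {DOCS_TABLE} c
    JOIN top_labels_per_doc t ON t.id = {join_key}
""")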

result_df.display()

# Optionally save results
# result_df.write.mode("overwrite").saveAsTable("my_catalog.my_schema.classification_results")
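Step 6 extracts the predicted label with get_json_object(..., '$.response[0]'). This cell keeps the full classification as a JSON string, so you can apply the same path afterwards, either in SQL or in Python as sketched below. The payload is a hypothetical example shaped to match that path, not real ai_classify output. Inspect a real row of result_df before relying on the structure.

```python
import json


def predicted_label(classification_json):
    """First entry of the 'response' array, matching the '$.response[0]' path used in step 6."""
    if classification_json is None:
        return None
    parsed = json.loads(classification_json)
    response = parsed.get("response") or []
    return response[0] if response else None


example = '{"response": ["shipping"]}'  # hypothetical payload
print(predicted_label(example))
```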