Pular para o conteúdo principal

Classifique documentos com mais de 500 rótulos

ai_classify aceita até 500 rótulos por chamada. Para taxonomias maiores, pré-filtre rótulos por documento usando similaridade de incorporação e, em seguida, chame ai_classify na lista reduzida de principais candidatos K. Este tutorial mostra como encontrar o K ideal — o menor número de candidatos que preserva a precisão.

nota

vector_cosine_similarity está disponível apenas no Databricks Runtime 18.1 e acima.

Antes de começar

  • Um workspace habilitado para Unity Catalog com acesso a ai_classify (consulte a disponibilidade).
  • Databricks Runtime 18.1 e superior (requerido para vector_cosine_similarity).
  • Uma tabela Delta de documentos para classificar.
  • Uma tabela Delta de rótulos com uma coluna key e uma coluna de descrição opcional.
  • Um pequeno conjunto de avaliação com rótulos da verdade fundamental, ou a capacidade de criar um (Opção B no passo 4).

0. Configuração

Defina seus nomes de tabela, nomes de coluna e modelo de embedding. A função auxiliar top_k_labels_json cria a expressão JSON passada para ai_classify para os candidatos principais K de cada documento.

Python
# -- Your tables --
DOCS_TABLE = "path.to.your_docs_table" # table of documents to classify
DOCS_TEXT_COL = "document" # column with text to classify
DOCS_ID_COL = None # unique ID column; set to None to auto-generate via md5

LABELS_TABLE = "path.to.your_labels_table" # table of labels
LABELS_KEY_COL = "label" # column with label value
LABELS_DESC_COL = "description" # description column; set to None if labels have no descriptions

# -- Embedding model --
EMBEDDING_MODEL = "databricks-qwen3-embedding-0-6b" # compact model, good default for English text

# -- K values to sweep --
K_VALUES = [10, 20, 50, 100, 200, 500]

# -- Eval set size (if you need to create one) --
EVAL_SAMPLE_SIZE = 100 # docs to sample for manual labeling
Python
doc_id_expr = DOCS_ID_COL if DOCS_ID_COL else f"md5({DOCS_TEXT_COL})"
label_embed_text = (
f"concat({LABELS_KEY_COL}, ': ', {LABELS_DESC_COL})"
if LABELS_DESC_COL
else LABELS_KEY_COL
)
def top_k_labels_json(k_val, prefix=""):
"""Build a JSON expression for top-K labels using max_by aggregate."""
col_prefix = f"{prefix}." if prefix else ""
if LABELS_DESC_COL:
return f"to_json(map_from_entries(max_by(struct({col_prefix}{LABELS_KEY_COL}, {col_prefix}{LABELS_DESC_COL}), {col_prefix}score, {k_val})))"
else:
return f"to_json(max_by({col_prefix}{LABELS_KEY_COL}, {col_prefix}score, {k_val}))"

print(f"Docs table: {DOCS_TABLE} (text: {DOCS_TEXT_COL}, id: {doc_id_expr})")
print(f"Labels table: {LABELS_TABLE} (key: {LABELS_KEY_COL}, desc: {LABELS_DESC_COL})")
print(f"Embed text: {label_embed_text}")
print(f"K sweep: {K_VALUES}")

1. Incorpore os rótulos

Execute isto uma vez. Executar novamente apenas quando a taxonomia mudar.

Python
spark.sql(f"""
CREATE OR REPLACE TABLE label_embeddings AS
SELECT
{LABELS_KEY_COL},
{f'{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
cast(
ai_query('{EMBEDDING_MODEL}', {label_embed_text}) AS ARRAY<FLOAT>
) AS embedding
FROM {LABELS_TABLE}
""")

label_count = spark.sql("SELECT count(*) AS n FROM label_embeddings").first()["n"]
print(f"Embedded {label_count} labels")

2. Incorporar os documentos

Python
spark.sql(f"""
CREATE OR REPLACE TABLE doc_embeddings AS
SELECT
{doc_id_expr} AS id,
{DOCS_TEXT_COL},
cast(
ai_query('{EMBEDDING_MODEL}', {DOCS_TEXT_COL}) AS ARRAY<FLOAT>
) AS embedding
FROM {DOCS_TABLE}
""")

doc_count = spark.sql("SELECT count(*) AS n FROM doc_embeddings").first()["n"]
print(f"Embedded {doc_count} documents")

3. Pontuar rótulos por similaridade de cosseno

Faça o cross-join de documentos com rótulos e compute a similaridade de cosseno para cada par. Isto produz N_docs × N_labels linhas. Para tabelas de documentos muito grandes, considere o particionamento de doc_embeddings e a execução disso em cada partição. Rótulos Top-K por documento são agregados posteriormente usando max_by(struct, score, K).

Python
spark.sql(f"""
CREATE OR REPLACE TABLE scored_labels AS
SELECT
d.id,
l.{LABELS_KEY_COL},
{f'l.{LABELS_DESC_COL},' if LABELS_DESC_COL else ''}
vector_cosine_similarity(d.embedding, l.embedding) AS score
FROM doc_embeddings d
CROSS JOIN label_embeddings l
""")

row_count = spark.sql("SELECT count(*) AS n FROM scored_labels").first()["n"]
print(f"Scored labels: {row_count:,} rows ({doc_count} docs × {label_count} labels)")

4. Prepare um conjunto de avaliação de verdade fundamental

K-tuning requer um pequeno conjunto de documentos com rótulos corretos conhecidos. Se houver uma tabela de avaliação existente, defina EVAL_TABLE na próxima célula e ignore a célula de amostragem. Caso contrário, a segunda célula apresenta documentos de amostra que você pode rotular manualmente e reimportar.

Python
# Option A: point to your existing eval table
# Must have columns: id (matching doc_embeddings.id) and ground_truth_label
EVAL_TABLE = dbutils.widgets.get("eval_table") # read from notebook widget

if EVAL_TABLE:
eval_df = spark.table(EVAL_TABLE)
print(f"Loaded {eval_df.count()} eval examples from {EVAL_TABLE}")
else:
print("No eval table set — run the next cell to sample documents for labeling.")
Python
# Option B: sample documents for manual labeling
if not EVAL_TABLE:
sample_df = spark.sql(f"""
SELECT id, {DOCS_TEXT_COL}
FROM doc_embeddings
ORDER BY rand()
LIMIT {EVAL_SAMPLE_SIZE}
""")
sample_df.display()
print(f"\nSampled {EVAL_SAMPLE_SIZE} documents.")
print("Next steps:")
print(" 1. Export these rows (copy the table above or save to CSV)")
print(" 2. Add a 'ground_truth_label' column and fill in the correct label for each doc")
print(" 3. Re-import as a Delta table and set EVAL_TABLE above")
print(" 4. Re-run cell 4 (Option A) to load it")

5. Medir Recall@K

Recall@K verifica se o rótulo de verdade fundamental aparece nos candidatos de embedding top-K. Esta é uma métrica somente de recuperação — não chama ai_classify e é executada instantaneamente.

Se a revocação for baixa para um determinado K, ai_classify não poderá retornar a resposta correta porque o rótulo correto foi excluído do conjunto de candidatos antes mesmo que a classificação fosse para execução.

Python
assert EVAL_TABLE, "Set EVAL_TABLE in cell 4 before running K-tuning."

spark.sql(f"CREATE OR REPLACE TEMP VIEW eval_set AS SELECT * FROM {EVAL_TABLE}")

recall_results = []
for k in K_VALUES:
row = spark.sql(f"""
SELECT
{k} AS k,
count(*) AS eval_size,
sum(CASE WHEN hit THEN 1 ELSE 0 END) AS hits,
round(sum(CASE WHEN hit THEN 1 ELSE 0 END) / count(*), 4) AS recall_at_k
FROM (
SELECT
e.id,
array_contains(
max_by(r.{LABELS_KEY_COL}, r.score, {k}),
e.ground_truth_label
) AS hit
FROM eval_set e
JOIN scored_labels r ON r.id = e.id
GROUP BY e.id, e.ground_truth_label
)
""").first()
recall_results.append(row.asDict())
print(f" K={k:&gt;4d} → Recall@K = {row['recall_at_k']:.2%} ({row['hits']}/{row['eval_size']})")

recall_df = spark.createDataFrame(recall_results)
recall_df.display()

6. Medir a precisão de ponta a ponta

Para cada K, construa o conjunto de rótulos top-K por documento de avaliação, execute ai_classify e compare com a verdade fundamental.

Este passo chama ai_classify e custa mais do que a verificação do Recall. Comece com os valores K onde o recall já é razoável.

Python
accuracy_results = []

for k in K_VALUES:
# Build top-K labels per eval doc using max_by aggregate
spark.sql(f"""
CREATE OR REPLACE TEMP VIEW eval_top_labels AS
SELECT
r.id,
{top_k_labels_json(k, 'r')} AS labels
FROM scored_labels r
JOIN eval_set e ON e.id = r.id
GROUP BY r.id
""")

# Materialize ai_classify first (returns VARIANT, and is non-deterministic so can't go inside aggregate)
spark.sql(f"""
CREATE OR REPLACE TEMP VIEW eval_predictions AS
SELECT
e.id,
e.ground_truth_label,
get_json_object(cast(ai_classify(d.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) as string), '$.response[0]') AS predicted_label
FROM eval_set e
JOIN doc_embeddings d ON d.id = e.id
JOIN eval_top_labels t ON t.id = e.id
""")

row = spark.sql(f"""
SELECT
{k} AS k,
count(*) AS eval_size,
sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) AS correct,
round(
sum(CASE WHEN predicted_label = ground_truth_label THEN 1 ELSE 0 END) / count(*),
4
) AS accuracy
FROM eval_predictions
""").first()

accuracy_results.append(row.asDict())
print(f" K={k:&gt;4d} → Accuracy = {row['accuracy']:.2%} ({row['correct']}/{row['eval_size']})")

accuracy_df = spark.createDataFrame(accuracy_results)
accuracy_df.display()

7. Comparar resultados e escolher K

O gráfico abaixo mostra Recall@K e a precisão de ponta a ponta lado a lado. Escolha o **menor K onde a precisão deixa de melhorar** — um K maior significa uma classificação mais lenta sem ganho de qualidade.

Python
import pandas as pd
import matplotlib.pyplot as plt

recall_pd = pd.DataFrame(recall_results)
accuracy_pd = pd.DataFrame(accuracy_results)
combined = recall_pd.merge(accuracy_pd, on="k", suffixes=("_recall", "_acc"))

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(combined["k"], combined["recall_at_k"], "o-", label="Recall@K", linewidth=2)
ax.plot(combined["k"], combined["accuracy"], "s--", label="End-to-end accuracy", linewidth=2)
ax.set_xlabel("K (candidate labels per document)")
ax.set_ylabel("Score")
ax.set_title("K-Tuning: Recall@K vs End-to-End Accuracy")
ax.set_ylim(0, 1.05)
ax.set_xticks(combined["k"])
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nFull results:")
print(combined[["k", "recall_at_k", "accuracy"]].to_string(index=False))
Python
# Pick your K based on the chart above
CHOSEN_K = 50 # <-- edit this

chosen_row = combined[combined["k"] == CHOSEN_K].iloc[0]
print(f"Chosen K = {CHOSEN_K}")
print(f" Recall@K: {chosen_row['recall_at_k']:.2%}")
print(f" End-to-end accuracy: {chosen_row['accuracy']:.2%}")

8. Executar classificação completa com o K escolhido

Aplique o K selecionado a toda a tabela de documentos.

Python
spark.sql(f"""
CREATE OR REPLACE TABLE top_labels_per_doc AS
SELECT
id,
{top_k_labels_json(CHOSEN_K, 'r')} AS labels
FROM scored_labels r
GROUP BY id
""")

print(f"Built top-{CHOSEN_K} label sets for all documents")
Python
result_df = spark.sql(f"""
SELECT
c.{DOCS_TEXT_COL},
cast(ai_classify(c.{DOCS_TEXT_COL}, t.labels, map('version', '2.0')) as string) AS classification
FROM {DOCS_TABLE} c
JOIN top_labels_per_doc t ON t.id = {doc_id_expr.replace(DOCS_TEXT_COL, f'c.{DOCS_TEXT_COL}')}
""")

result_df.display()

# Optionally save results
# result_df.write.mode("overwrite").saveAsTable("my_catalog.my_schema.classification_results")