ai_classify

Classifies document content into one of the provided labels using a large language model (LLM).

For the corresponding Databricks SQL function, see ai_classify function.

Syntax

Python
from pyspark.sql import functions as dbf

dbf.ai_classify(col=<col>, labels=<labels>, options=<options>)

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| col | pyspark.sql.Column or str | A column containing the document content to classify. |
| labels | list, dict, pyspark.sql.Column, or str | Either a literal label set (Python list of label strings or dict mapping label names to descriptions, serialized to a JSON literal automatically) or a column expression whose per-row value is a JSON array of label strings or a JSON object mapping label names to descriptions. |
| options | dict, optional | A dictionary of options to control classification behavior. |
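
To make the two literal label shapes concrete, the sketch below uses Python's standard `json` module to show what a list and a dict look like as JSON literals. It only illustrates the format described above. Whether `ai_classify` serializes labels in exactly this way internally is an assumption, and the helper name `labels_to_json` is hypothetical.

```python
import json


def labels_to_json(labels):
    """Hypothetical helper: render a label set as a JSON literal string."""
    if isinstance(labels, (list, tuple)):
        # A plain list of label names becomes a JSON array of strings.
        return json.dumps(list(labels))
    if isinstance(labels, dict):
        # A dict of label name -> description becomes a JSON object.
        return json.dumps(labels)
    raise TypeError("labels must be a list or a dict")


array_literal = labels_to_json(["positive", "negative", "neutral"])
object_literal = labels_to_json({"positive": "Happy tone", "negative": "Unhappy tone"})

print(array_literal)
print(object_literal)
```

The same two shapes, a JSON array or a JSON object, are what a per-row labels column is expected to hold as its string value.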

Returns

pyspark.sql.Column: A new column containing the classification result.

By default, the function performs single-label classification. For how to enable multi-label classification, and for the full set of supported options, see the SQL language manual.

Examples

Python
from pyspark.sql import functions as dbf
from pyspark.sql.functions import col

# Assumes df is an existing DataFrame with a string column "text"

# Static labels (same set for every row)
df.select(dbf.ai_classify("text", ["positive", "negative", "neutral"]))
df.select(dbf.ai_classify("text", {"positive": "Happy tone", "negative": "Unhappy tone"}))

# Per-row labels (a column whose value is a JSON array or JSON object string)
df.select(dbf.ai_classify("text", col("labels_json")))
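
The per-row form needs a column whose value is a JSON string. One way such a column might be prepared is sketched below in plain Python, before the rows are turned into a DataFrame. The sample records are made up, the column names `text` and `labels_json` follow the example above, and the `spark` session in the final comment is assumed to exist.

```python
import json

# Made-up sample records: each document carries its own candidate labels.
records = [
    ("The battery died after a week.", ["defect", "shipping", "billing"]),
    ("I was charged twice for one order.", ["billing", "account", "other"]),
]

# Serialize each row's label list to a JSON array string.
rows = [(text, json.dumps(labels)) for text, labels in records]

for text, labels_json in rows:
    print(text, "->", labels_json)

# With an active SparkSession named `spark` (an assumption), the rows could
# then be loaded as the DataFrame used in the per-row example:
# df = spark.createDataFrame(rows, ["text", "labels_json"])
```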