Perform batch LLM inference using ai_query

This feature is in Public Preview.

This article describes how to perform batch inference using the built-in Databricks SQL function ai_query with an endpoint that uses Foundation Model APIs provisioned throughput. The examples and guidance in this article are recommended for batch inference workloads that use large language models (LLMs) to process multiple inputs.

You can use ai_query with either SQL or PySpark to run batch inference workloads. To run batch inference on your data, specify the following in ai_query:

  • The Unity Catalog input table and output table

  • The provisioned throughput endpoint name

  • The model prompt and any model parameters

See ai_query function for more detail about this AI function.
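If you drive batch jobs from Python, you may want to collect the inputs listed above in one place and catch malformed table names early. The following is a minimal sketch, not a Databricks API: `BatchInferenceConfig` and `is_three_level_name` are hypothetical helpers. They assume tables are given as fully qualified Unity Catalog names in the `catalog.schema.table` form, check only the shape of those names (not whether the tables exist), and do not handle backtick-quoted names that contain dots.

```python
from dataclasses import dataclass, field


def is_three_level_name(name: str) -> bool:
    """Return True if name has the catalog.schema.table shape (three non-empty parts)."""
    parts = name.split(".")
    return len(parts) == 3 and all(part.strip() for part in parts)


@dataclass
class BatchInferenceConfig:
    """Hypothetical container for the values an ai_query batch job needs."""

    input_table: str   # Unity Catalog input table, e.g. "main.default.reviews"
    output_table: str  # Unity Catalog output table
    endpoint: str      # provisioned throughput endpoint name
    prompt: str        # instruction prepended to each input row
    model_parameters: dict = field(default_factory=dict)  # e.g. {"max_tokens": 256}

    def __post_init__(self) -> None:
        # Only the shape of the names is checked, not whether the tables exist.
        for label, name in (("input_table", self.input_table), ("output_table", self.output_table)):
            if not is_three_level_name(name):
                raise ValueError(f"{label} must be catalog.schema.table, got {name!r}")
        if not self.endpoint.strip():
            raise ValueError("endpoint must be a non-empty endpoint name")
```

For example, `BatchInferenceConfig("main.default.reviews", "main.default.reviews_scored", "my-pt-endpoint", "Summarize: ", {"max_tokens": 100})` passes the checks, while an input table given as just `reviews` raises `ValueError`. The table and endpoint names here are placeholders.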

Requirements

  • A workspace in a Foundation Model APIs supported region.

  • One of the following:

    • All-purpose compute running Databricks Runtime 15.4 LTS ML or above, with an instance size of i3.2xlarge or larger and at least two workers.

    • A SQL warehouse of size Medium or larger.

  • An existing provisioned throughput model serving endpoint. See Provisioned throughput Foundation Model APIs to create one.

  • Query permission on the Delta table in Unity Catalog that contains the data you want to use.

  • Set pipelines.channel in the table properties to 'preview' to use ai_query(). See Examples for a sample query.
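If you script cluster setup, you might check the all-purpose compute requirement before launching a job. The sketch below assumes the runtime is identified by a Databricks `spark_version` string such as `15.4.x-cpu-ml-scala2.12` (the form used when configuring clusters through the API). `meets_runtime_requirements` is a hypothetical helper: it checks only the runtime version, the ML flavor, and the worker count, and it does not verify the instance type or LTS status.

```python
import re


def meets_runtime_requirements(spark_version: str, num_workers: int) -> bool:
    """Check for an ML runtime at version 15.4 or above with at least two workers.

    Assumes spark_version strings like "15.4.x-cpu-ml-scala2.12".
    Instance size (i3.2xlarge or larger) is not checked here.
    """
    match = re.match(r"^(\d+)\.(\d+)\.x-", spark_version)
    if match is None:
        return False  # unrecognized format: treat as not meeting the requirement
    version = (int(match.group(1)), int(match.group(2)))
    is_ml = "-ml-" in spark_version  # ML runtimes carry an "-ml-" segment
    return version >= (15, 4) and is_ml and num_workers >= 2
```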

Use ai_query and SQL

The following is a batch inference example using ai_query and SQL. This example includes modelParameters with max_tokens and temperature and shows how to concatenate the prompt for your model and the input column using concat(). There are multiple ways to perform concatenation, such as using ||, concat(), or format_string().

-- Replace each ${...} placeholder with your own value before running.
CREATE OR REPLACE TABLE ${output_table_name} AS (
  SELECT
    ${input_column_name},
    ai_query(
      '${endpoint}',
      concat('${prompt}', ${input_column_name}),
      modelParameters => named_struct('max_tokens', ${num_output_tokens}, 'temperature', ${temperature})
    ) AS response
  FROM ${input_table_name}
  LIMIT ${input_num_rows}
)
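The query above builds its prompt with concat() and its model parameters with named_struct. If you assemble those SQL fragments from Python, a few small helpers keep the quoting consistent. The following is a sketch with hypothetical helper names, not a Databricks-provided API. It assumes Spark SQL's default backslash escaping inside single-quoted string literals, inserts the column name verbatim (so it must already be a valid identifier), and accepts only string and numeric parameter values. For non-NULL inputs, the three concatenation variants should yield the same prompt text.

```python
def sql_string_literal(text: str) -> str:
    """Quote text as a single-quoted Spark SQL literal, escaping backslashes and quotes."""
    escaped = text.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def prompt_expressions(prompt: str, column: str) -> dict:
    """Three ways to prepend a prompt literal to a column: ||, concat(), format_string()."""
    lit = sql_string_literal(prompt)
    return {
        "pipe": f"{lit} || {column}",
        "concat": f"concat({lit}, {column})",
        "format_string": f"format_string('%s%s', {lit}, {column})",
    }


def named_struct_sql(params: dict) -> str:
    """Render {"max_tokens": 100, "temperature": 0.7} as a named_struct(...) expression."""
    parts = []
    for name, value in params.items():
        if isinstance(value, str):
            rendered = sql_string_literal(value)  # strings become quoted literals
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            rendered = repr(value)  # numbers are written as-is
        else:
            raise TypeError(f"unsupported value for {name!r}: {value!r}")
        parts.append(f"{sql_string_literal(name)}, {rendered}")
    return "named_struct(" + ", ".join(parts) + ")"
```

For example, `prompt_expressions("Summarize: ", "review")["concat"]` returns `concat('Summarize: ', review)`, and `named_struct_sql({"max_tokens": 100, "temperature": 0.7})` returns `named_struct('max_tokens', 100, 'temperature', 0.7)`.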

Use ai_query and PySpark

If you prefer Python, you can also run batch inference with ai_query and PySpark, as shown in the following example:

# Assumes the placeholder names below are Python variables you define and df is your input DataFrame.
df_out = df.selectExpr(f"ai_query('{endpoint_name}', CONCAT('{prompt}', {input_column_name}), modelParameters => named_struct('max_tokens', {num_output_tokens}, 'temperature', {temperature})) AS {output_column_name}")

df_out.write.mode("overwrite").saveAsTable(output_table_name)

Batch inference example notebook using Python

The following example notebook creates a provisioned throughput endpoint and runs batch LLM inference with Python and the Meta Llama 3.1 70B model. It also includes guidance on benchmarking your batch inference workload and on creating a provisioned throughput model serving endpoint.
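Once benchmarking gives you a measured throughput for your endpoint, you can turn it into a rough runtime estimate. The sketch below is a back-of-the-envelope model, not a Databricks formula: it assumes runtime is dominated by the total number of tokens processed and that one measured tokens-per-second rate holds for the whole job. In practice, input and output tokens are usually processed at different rates, so treat the result as an order-of-magnitude guide. `estimate_batch_seconds` and its inputs are hypothetical, and every number should come from your own benchmark.

```python
def estimate_batch_seconds(num_rows: int, avg_input_tokens: float,
                           avg_output_tokens: float, tokens_per_second: float) -> float:
    """Rough runtime estimate: total tokens divided by measured endpoint throughput."""
    if tokens_per_second <= 0:
        raise ValueError("tokens_per_second must be positive")
    # Treats input and output tokens as costing the same; refine if your benchmark
    # reports them separately.
    total_tokens = num_rows * (avg_input_tokens + avg_output_tokens)
    return total_tokens / tokens_per_second
```

With illustrative inputs of 10,000 rows averaging 500 input and 100 output tokens at 2,000 tokens per second, the estimate is 6,000,000 / 2,000 = 3,000 seconds, or 50 minutes.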

LLM batch inference with a provisioned throughput endpoint notebook
