Author AI agents in code
This article shows how to author an AI agent in Python using Mosaic AI Agent Framework and popular agent-authoring libraries like LangGraph, PyFunc, and OpenAI.
Requirements
Databricks recommends installing the latest version of the MLflow Python client when developing agents.
To author and deploy agents using the approach in this article, you must meet the following requirements:
- Install databricks-agents version 0.16.0 and above
- Install mlflow version 2.20.2 and above
%pip install -U -qqqq databricks-agents>=0.16.0 mlflow>=2.20.2
Databricks also recommends installing the Databricks AI Bridge library. The AI Bridge library provides a shared layer of APIs to interact with Databricks AI features, such as Databricks AI/BI Genie and Vector Search.
This library also contains the source code for databricks-langchain and databricks-openai. These integration packages provide seamless integration of Databricks AI features to use in AI authoring frameworks.
%pip install -U -qqqq databricks-ai-bridge
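For example, if you also install the databricks-langchain integration package, you can call a Databricks model serving endpoint directly from LangChain. The following is a minimal sketch; the endpoint name is illustrative and should be replaced with an endpoint available in your workspace.

# Minimal sketch, assuming the databricks-langchain integration package is installed.
# The endpoint name below is illustrative.
from databricks_langchain import ChatDatabricks

llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")
print(llm.invoke("What is Databricks?").content)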
Use ChatAgent to author agents
To standardize and streamline agent authoring, Databricks recommends using MLflow's ChatAgent interface to create production-ready AI agents. ChatAgent is a chat schema specification designed for agent scenarios.

By using ChatAgent, developers can create agents compatible with Databricks and MLflow tools for agent tracking, evaluation, and lifecycle management, which are essential for deploying production-ready models.

To learn how to create a ChatAgent, see the examples in the following section and the MLflow documentation: What is the ChatAgent interface.
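The following is a minimal sketch of a ChatAgent subclass, shown only to illustrate the interface; the class name and echo behavior are illustrative, not a production agent.

import uuid
from typing import Any, Optional

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse, ChatContext

class EchoChatAgent(ChatAgent):
    """Minimal ChatAgent that echoes the user's last message."""

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # Echo the last user message back as the assistant reply
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(
                    role="assistant",
                    content=messages[-1].content,
                    id=str(uuid.uuid4()),
                )
            ]
        )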
ChatAgent examples

The following notebooks show how to author streaming and non-streaming ChatAgents using popular libraries like OpenAI and LangGraph.
LangGraph tool-calling agent
OpenAI tool-calling agent
OpenAI simple agent
Deploy a ChatAgent to Databricks Model Serving

Databricks deploys ChatAgents in a distributed environment on Databricks Model Serving, which means that during a multi-turn conversation, the same serving replica may not handle all requests. Pay attention to the following implications for managing agent state:
- No local caching: When deploying a ChatAgent, do not assume that the same serving replica will serve all requests during a multi-turn conversation. Do not cache the result of an individual conversation turn and reuse it as context in a future turn. Instead, a stateful agent must reconstruct its internal state from the ChatAgentRequest dictionary it receives.
- Thread-safe state: The agent's state must be thread-safe to prevent conflicts in a multi-threaded environment.
- Instantiate state during the predict function: Instantiate state each time the predict function is called, rather than when the ChatAgent is initialized. Because a single ChatAgent replica could handle requests from multiple conversations, storing state at the ChatAgent level could leak information between conversations and cause conflicts. See the sketch after this list.
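The following sketch illustrates these guidelines with a ChatAgent that reconstructs its conversation state inside predict on every request; the class name and reply logic are illustrative.

import uuid
from typing import Any, Optional

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse, ChatContext

class StatelessReplicaAgent(ChatAgent):
    def __init__(self):
        # Do not store per-conversation state here: one replica can serve
        # requests from many different conversations.
        super().__init__()

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # Reconstruct conversation state from the request payload on every call
        history = [f"{m.role}: {m.content}" for m in messages]
        reply = f"This conversation has {len(history)} messages so far."
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(role="assistant", content=reply, id=str(uuid.uuid4()))
            ]
        )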
Custom inputs and outputs

Some scenarios may require additional agent inputs, such as client_type and session_id, or outputs like retrieval source links that should not be included in the chat history for future interactions.

For these scenarios, MLflow ChatAgent natively supports custom_inputs and custom_outputs values.

The Agent Evaluation review app does not currently support rendering traces for agents with additional input fields.

See the following examples to learn how to set custom inputs and outputs for OpenAI/PyFunc and LangGraph agents.
OpenAI + PyFunc custom schema agent notebook
LangGraph custom schema agent notebook
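The following is a minimal sketch of a ChatAgent that reads custom_inputs and returns custom_outputs; the client_type key and the returned link are illustrative.

import uuid
from typing import Any, Optional

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse, ChatContext

class CustomIOAgent(ChatAgent):
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # Read an optional, caller-supplied field without adding it to the chat history
        client_type = (custom_inputs or {}).get("client_type", "unknown")
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(
                    role="assistant",
                    content=f"Answering a request from a {client_type} client.",
                    id=str(uuid.uuid4()),
                )
            ],
            # Extra values returned alongside the reply, such as retrieval source links
            custom_outputs={"source_links": ["https://example.com/doc"]},
        )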
Provide custom_inputs in the AI Playground and agent review app

If your agent accepts additional inputs using the custom_inputs field, you can manually provide these inputs in both the AI Playground and the agent review app.
- In either the AI Playground or the Agent Review App, select the gear icon.
- Enable custom_inputs.
- Provide a JSON object that matches your agent's defined input schema.
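For example, a custom_inputs JSON object for an agent that accepts the fields mentioned earlier might look like the following; the keys are illustrative and must match your agent's input schema.

{
  "client_type": "web",
  "session_id": "1234-abcd"
}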
Set retriever schema
AI agents often use retrievers, a type of agent tool that finds and returns relevant documents using a Vector Search index. For more information on retrievers, see Unstructured retrieval AI agent tools.
To ensure that retrievers are traced properly, call mlflow.models.set_retriever_schema when you define your agent in code. Use set_retriever_schema to map the column names in the returned table to MLflow's expected fields such as primary_key, text_column, and doc_uri.
import mlflow

# Define the retriever's schema by providing your column names
# These strings should be read from a config dictionary
mlflow.models.set_retriever_schema(
    name="vector_search",
    primary_key="chunk_id",
    text_column="text_column",
    doc_uri="doc_uri",
    # other_columns=["column1", "column2"],
)
The doc_uri column is especially important when evaluating the retriever's performance. doc_uri is the main identifier for documents returned by the retriever, allowing you to compare them against ground truth evaluation sets. See Evaluation sets.

You can also specify additional columns in your retriever's schema by providing a list of column names with the other_columns field.
If you have multiple retrievers, you can define multiple schemas by using unique names for each retriever schema.
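For example, the following sketch registers two retriever schemas with unique names; the index and column names are illustrative.

import mlflow

# One schema per retriever, each with a unique name
mlflow.models.set_retriever_schema(
    name="docs_vector_search",
    primary_key="chunk_id",
    text_column="chunk_text",
    doc_uri="doc_uri",
)
mlflow.models.set_retriever_schema(
    name="support_tickets_vector_search",
    primary_key="ticket_id",
    text_column="ticket_text",
    doc_uri="ticket_url",
)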
Use parameters to configure the agent
In the Agent Framework, you can use parameters to control how agents are executed. This allows you to quickly iterate by varying characteristics of your agent without changing the code. Parameters are key-value pairs that you define in a Python dictionary or a .yaml file.

To configure the code, create a ModelConfig, a set of key-value parameters. ModelConfig is either a Python dictionary or a .yaml file. For example, you can use a dictionary during development and then convert it to a .yaml file for production deployment and CI/CD. For details about ModelConfig, see the MLflow documentation.

An example ModelConfig is shown below.
llm_parameters:
max_tokens: 500
temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
question that indicates your prompt template came from a YAML file. Your response
must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question
To call the configuration from your code, use one of the following:
import mlflow

# Example for loading from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Example of using a dictionary
config_dict = {
    "prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}
model_config = mlflow.models.ModelConfig(development_config=config_dict)

# Use model_config.get() to retrieve a parameter value defined in the config
value = model_config.get("prompt_template")
Streaming error propagation
Mosaic AI propagates any errors encountered while streaming with the last token under databricks_output.error. It is up to the calling client to properly handle and surface this error.
{
  "delta": …,
  "databricks_output": {
    "trace": {...},
    "error": {
      "error_code": "BAD_REQUEST",
      "message": "TimeoutException: Tool XYZ failed to execute"
    }
  }
}
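The following is a minimal client-side sketch for surfacing the propagated error, assuming each streamed chunk has already been parsed into a Python dictionary.

def consume_stream(chunks):
    """Yield streamed deltas, raising if a chunk carries a propagated error."""
    for chunk in chunks:
        error = chunk.get("databricks_output", {}).get("error")
        if error:
            # Surface the propagated error instead of silently dropping it
            raise RuntimeError(f"{error.get('error_code')}: {error.get('message')}")
        yield chunk.get("delta")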
Author a ChatModel agent
ChatModel is an MLflow class that is available as an alternative to the ChatAgent interface. ChatModel extends OpenAI's ChatCompletion schema, allowing you to maintain broad compatibility with platforms supporting the ChatCompletion standard while adding custom functionality.

Databricks recommends ChatModel only when strict adherence to the OpenAI ChatCompletion signature is required. For all other scenarios, Databricks recommends ChatAgent.
The following table compares the interface options:
Interface | Recommendations
---|---
ChatAgent | Recommended for basic and advanced agent scenarios.
ChatModel | Recommended when strict adherence to the OpenAI ChatCompletion signature is required.
See MLflow: Getting Started with ChatModel.
You can author your agent as a subclass of mlflow.pyfunc.ChatModel. This method provides the following benefits:
- Enables streaming agent output when invoking a served agent (by passing {stream: true} in the request body).
- Automatically enables AI gateway inference tables when your agent is served, providing access to enhanced request log metadata, such as the requester name.
- Allows you to write agent code compatible with the ChatCompletion schema using typed Python classes.
- MLflow automatically infers a chat completion-compatible signature when logging the agent, even without an input_example. This simplifies the process of registering and deploying the agent. See Infer Model Signature during logging.
The following code is best executed in a Databricks notebook. Notebooks provide a convenient environment for developing, testing, and iterating on your agent.
The MyAgent class extends mlflow.pyfunc.ChatModel and implements the required predict method. This ensures compatibility with Mosaic AI Agent Framework.

The class also includes the optional methods _create_chat_completion_chunk and predict_stream to handle streaming outputs.
import re
from dataclasses import dataclass
from typing import Optional, Dict, List, Generator
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import (
# Non-streaming helper classes
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
ChatChoice,
ChatParams,
# Helper classes for streaming agent output
ChatChoiceDelta,
ChatChunkChoice,
)
class MyAgent(ChatModel):
"""
Defines a custom agent that processes ChatCompletionRequests
and returns ChatCompletionResponses.
"""
def predict(self, context, messages: list[ChatMessage], params: ChatParams) -> ChatCompletionResponse:
last_user_question_text = messages[-1].content
response_message = ChatMessage(
role="assistant",
content=(
f"I will always echo back your last question. Your last question was: {last_user_question_text}. "
)
)
return ChatCompletionResponse(
choices=[ChatChoice(message=response_message)]
)
def _create_chat_completion_chunk(self, content) -> ChatCompletionChunk:
"""Helper for constructing a ChatCompletionChunk instance for wrapping streaming agent output"""
return ChatCompletionChunk(
choices=[ChatChunkChoice(
delta=ChatChoiceDelta(
role="assistant",
content=content
)
)]
)
def predict_stream(
self, context, messages: List[ChatMessage], params: ChatParams
) -> Generator[ChatCompletionChunk, None, None]:
last_user_question_text = messages[-1].content
yield self._create_chat_completion_chunk(f"Echoing back your last question, word by word.")
for word in re.findall(r"\S+\s*", last_user_question_text):
yield self._create_chat_completion_chunk(word)
agent = MyAgent()
model_input = ChatCompletionRequest(
messages=[ChatMessage(role="user", content="What is Databricks?")]
)
response = agent.predict(context=None, messages=model_input.messages, params=ChatParams())
print(response)
While the agent class MyAgent is defined in one notebook, you should create a separate driver notebook. The driver notebook logs the agent to Model Registry and deploys the agent using Model Serving.

This separation follows the workflow recommended by Databricks for logging models using MLflow's Models from Code methodology.
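The following is a minimal driver notebook sketch, assuming the agent code is saved in a file named agent.py that also calls mlflow.models.set_model(MyAgent()); the file name and input example are illustrative.

import mlflow

# Log the agent using Models from Code: python_model points to the agent's source file
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",  # illustrative path to the file defining MyAgent
        input_example={
            "messages": [{"role": "user", "content": "What is Databricks?"}]
        },
    )

print(logged_agent_info.model_uri)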
Wrap LangChain in ChatModel
If you have an existing LangChain model and want to integrate it with other Mosaic AI agent features, you can wrap it in an MLflow ChatModel to ensure compatibility.

This code sample uses the following steps to wrap a LangChain runnable as a ChatModel:

- Wrap the final output of the LangChain runnable with mlflow.langchain.output_parsers.ChatCompletionOutputParser to produce a chat completion output signature.
- The LangchainAgent class extends mlflow.pyfunc.ChatModel and implements two key methods:
  - predict: Handles synchronous predictions by invoking the chain and returning a formatted response.
  - predict_stream: Handles streaming predictions by invoking the chain and yielding chunks of responses.
from mlflow.langchain.output_parsers import ChatCompletionOutputParser
from mlflow.pyfunc import ChatModel
from typing import Optional, Dict, List, Generator
from mlflow.types.llm import (
    ChatCompletionResponse,
    ChatCompletionChunk,
    ChatMessage,
    ChatParams,
)

chain = (
    <your chain here>
    | ChatCompletionOutputParser()
)

class LangchainAgent(ChatModel):
    def _prepare_messages(self, messages: List[ChatMessage]):
        return {"messages": [m.to_dict() for m in messages]}

    def predict(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        question = self._prepare_messages(messages)
        response_message = chain.invoke(question)
        return ChatCompletionResponse.from_dict(response_message)

    def predict_stream(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> Generator[ChatCompletionChunk, None, None]:
        question = self._prepare_messages(messages)
        for chunk in chain.stream(question):
            yield ChatCompletionChunk.from_dict(chunk)