%md [‹ Back to Table of Contents](index.html)
%md
# Mosaic AI Agent Framework: Author and deploy a tool-calling OpenAI agent

This notebook demonstrates how to author an OpenAI agent that's compatible with Mosaic AI Agent Framework features. In this notebook you learn to:

- Author a tool-calling OpenAI `ChatAgent`
- Manually test the agent's output
- Evaluate the agent using Mosaic AI Agent Evaluation
- Log and deploy the agent

To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/author-agent) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/create-chat-model)).

## Prerequisites

- Address all `TODO`s in this notebook.
%pip install -U -qqqq mlflow backoff databricks-openai databricks-agents uv
dbutils.library.restartPython()
%md
## Define the agent in code

Define the agent code in a single cell below. This lets you easily write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.

#### Agent tools

This agent code adds the built-in Unity Catalog function `system.ai.python_exec` to the agent. The agent code also includes commented-out sample code for adding a vector search index to perform unstructured data retrieval.

For more examples of tools to add to your agent, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).
%%writefile agent.py
import json
from typing import Any, Callable, Dict, Generator, List, Optional
from uuid import uuid4

import backoff
import mlflow
import openai
from databricks.sdk import WorkspaceClient
from databricks_openai import VectorSearchRetrieverTool, UCFunctionToolkit
from unitycatalog.ai.core.base import get_uc_function_client
from mlflow.entities import SpanType
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from openai import OpenAI
from pydantic import BaseModel

############################################
# Define your LLM endpoint and system prompt
############################################

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"

# TODO: Update with your system prompt
SYSTEM_PROMPT = """
You are a helpful assistant.
"""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
class ToolInfo(BaseModel):
    name: str
    spec: dict
    exec_fn: Callable


TOOL_INFOS = []

# You can use UDFs in Unity Catalog as agent tools
# Below, we add the `system.ai.python_exec` UDF, which provides
# a python code interpreter tool to our agent
# TODO: Add additional tools
UC_TOOL_NAMES = ["system.ai.python_exec"]

uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
uc_function_client = get_uc_function_client()


# Define a wrapper that accepts kwargs for the UC tool call, then passes them
# to the UC tool execution client. A factory function binds udf_name at
# definition time, so each wrapper calls its own UDF even when multiple UC
# tools are registered (avoids Python's late-binding closure pitfall).
def make_exec_fn(udf_name: str) -> Callable:
    def execute_uc_tool(**kwargs):
        function_result = uc_function_client.execute_function(udf_name, kwargs)
        if function_result.error is not None:
            return function_result.error
        else:
            return function_result.value

    return execute_uc_tool


for tool_spec in uc_toolkit.tools:
    # TODO: comment this line out if not using with databricks-claude-3-7-sonnet
    del tool_spec["function"]["strict"]
    tool_name = tool_spec["function"]["name"]
    udf_name = tool_name.replace("__", ".")
    TOOL_INFOS.append(ToolInfo(name=tool_name, spec=tool_spec, exec_fn=make_exec_fn(udf_name)))

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# for details
VECTOR_SEARCH_TOOLS = []

# TODO: Add vector search indexes
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )

for vs_tool in VECTOR_SEARCH_TOOLS:
    TOOL_INFOS.append(
        ToolInfo(
            name=vs_tool.tool["function"]["name"],
            spec=vs_tool.tool,
            exec_fn=vs_tool.execute,
        )
    )


class ToolCallingAgent(ChatAgent):
    """
    Class representing a tool-calling agent
    """

    def __init__(self, llm_endpoint: str, tools: List[ToolInfo]):
        """
        Initializes the ToolCallingAgent with tools.

        Args:
            llm_endpoint (str): Name of the model serving endpoint to call.
            tools (List[ToolInfo]): A list of ToolInfo objects, each carrying the
                tool's name, its spec (JSON description of the tool, matching the
                OpenAI format), and the exec_fn callable that implements the
                tool logic.
        """
        super().__init__()
        self.llm_endpoint = llm_endpoint
        self.workspace_client = WorkspaceClient()
        self.model_serving_client: OpenAI = (
            self.workspace_client.serving_endpoints.get_open_ai_client()
        )
        self._tools_dict = {
            tool.name: tool for tool in tools
        }  # Store tools for later execution

    def get_tool_specs(self):
        """
        Returns tool specifications in the format OpenAI expects.
        """
        return [tool_info.spec for tool_info in self._tools_dict.values()]

    @mlflow.trace(span_type=SpanType.TOOL)
    def execute_tool(self, tool_name: str, args: dict) -> Any:
        """
        Executes the specified tool with the given arguments.

        Args:
            tool_name (str): The name of the tool to execute.
            args (dict): Arguments for the tool.

        Returns:
            Any: The tool's output.
        """
        if tool_name not in self._tools_dict:
            raise ValueError(f"Unknown tool: {tool_name}")
        return self._tools_dict[tool_name].exec_fn(**args)

    def prepare_messages_for_llm(
        self, messages: list[ChatAgentMessage]
    ) -> list[dict[str, Any]]:
        """Filter out ChatAgentMessage fields that are not compatible with LLM message formats"""
        compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
        return [
            {
                k: v
                for k, v in m.model_dump_compat(exclude_none=True).items()
                if k in compatible_keys
            }
            for m in messages
        ]

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(
        self,
        messages: List[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """
        Primary function that takes a user's request and generates a response.
        """
        # NOTE: this assumes that each chunk streamed by self.call_and_run_tools
        # contains a full message (i.e. chunk.delta is a complete message).
        # This is simple to implement, but you can also stream partial response
        # messages from predict_stream, and aggregate them here by message ID
        response_messages = [
            chunk.delta
            for chunk in self.predict_stream(messages, context, custom_inputs)
        ]
        return ChatAgentResponse(messages=response_messages)

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict_stream(
        self,
        messages: List[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        if len(messages) == 0:
            raise ValueError(
                "The list of `messages` passed to predict(...) must contain at least one message"
            )
        all_messages = [
            ChatAgentMessage(role="system", content=SYSTEM_PROMPT)
        ] + messages

        for message in self.call_and_run_tools(messages=all_messages):
            yield ChatAgentChunk(delta=message)

    @backoff.on_exception(backoff.expo, openai.RateLimitError)
    def chat_completion(self, messages: List[ChatAgentMessage]):
        return self.model_serving_client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(messages),
            tools=self.get_tool_specs(),
        )

    @mlflow.trace(span_type=SpanType.AGENT)
    def call_and_run_tools(
        self, messages, max_iter=10
    ) -> Generator[ChatAgentMessage, None, None]:
        current_msg_history = messages.copy()
        for i in range(max_iter):
            with mlflow.start_span(span_type="AGENT", name=f"iteration_{i + 1}"):
                # Get an assistant response from the model, add it to the running
                # history, and yield it to the caller.
                # NOTE: we perform a simple non-streaming chat completion here.
                # Use the streaming API if you'd like to additionally do token
                # streaming of agent output.
                response = self.chat_completion(messages=current_msg_history)
                llm_message = response.choices[0].message
                assistant_message = ChatAgentMessage(**llm_message.to_dict(), id=str(uuid4()))
                current_msg_history.append(assistant_message)
                yield assistant_message

                tool_calls = assistant_message.tool_calls
                if not tool_calls:
                    return  # Stop streaming if no tool calls are needed

                # Execute tool calls, add them to the running message history,
                # and yield their results as tool messages
                for tool_call in tool_calls:
                    function = tool_call.function
                    args = json.loads(function.arguments)
                    # Cast the tool result to a string, since not all tools return a string
                    result = str(self.execute_tool(tool_name=function.name, args=args))
                    tool_call_msg = ChatAgentMessage(
                        role="tool",
                        name=function.name,
                        tool_call_id=tool_call.id,
                        content=result,
                        id=str(uuid4()),
                    )
                    current_msg_history.append(tool_call_msg)
                    yield tool_call_msg

        yield ChatAgentMessage(
            content=f"I'm sorry, I couldn't determine the answer after trying {max_iter} times.",
            role="assistant",
            id=str(uuid4()),
        )


# Log the model using MLflow
mlflow.openai.autolog()
AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME, tools=TOOL_INFOS)
mlflow.models.set_model(AGENT)
%md
## Test the agent

Interact with the agent to test its output. Since we manually traced methods within `ChatAgent`, you can view the trace for each step the agent takes, with any LLM calls made via the OpenAI SDK automatically traced by autologging.

Replace this placeholder input with an appropriate domain-specific example for your agent.
dbutils.library.restartPython()
from agent import AGENT

AGENT.predict({"messages": [{"role": "user", "content": "What is 4*3 in Python?"}]})
for chunk in AGENT.predict_stream(
    {"messages": [{"role": "user", "content": "What is 4*3 in Python?"}]}
):
    print(chunk, "-----------\n")
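%md Besides browsing traces in the MLflow UI, you can fetch them programmatically. A minimal sketch, assuming the calls above logged traces to the notebook's current experiment:

import mlflow

# Fetch the most recent traces for this experiment as a pandas DataFrame
recent_traces = mlflow.search_traces(max_results=5, order_by=["timestamp_ms DESC"])
display(recent_traces)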
%md
## Log the agent as an MLflow model

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources

For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.

- **TODO**: If your Unity Catalog tool queries a vector search index or leverages external functions, you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)) and the sketch after the logging cell below.
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import LLM_ENDPOINT_NAME, UC_TOOL_NAMES, VECTOR_SEARCH_TOOLS
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from pkg_resources import get_distribution

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in VECTOR_SEARCH_TOOLS:
    resources.extend(tool.resources)
for tool_name in UC_TOOL_NAMES:
    resources.append(DatabricksFunction(function_name=tool_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        extra_pip_requirements=[
            f"databricks-connect=={get_distribution('databricks-connect').version}"
        ],
        resources=resources,
    )
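%md The loop above already picks up resources from any `VectorSearchRetrieverTool` you define in `agent.py`. If your UC tools instead query a vector search index or call external functions through a UC connection directly, declare those dependencies explicitly, then re-run the logging cell so authentication passthrough covers them. A minimal sketch; the index and connection names below are hypothetical placeholders:

from mlflow.models.resources import DatabricksUCConnection, DatabricksVectorSearchIndex

# Hypothetical names -- replace with the index/connection your tools depend on
resources.append(DatabricksVectorSearchIndex(index_name="main.default.my_index"))
resources.append(DatabricksUCConnection(connection_name="my_connection"))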
%md
## Evaluate the agent with Agent Evaluation

Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics. See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).

To evaluate your tool calls, add custom metrics, as sketched after the evaluation cell below. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-evaluation/custom-metrics#evaluating-tool-calls)).
import pandas as pd

eval_examples = [
    {
        "request": {"messages": [{"role": "user", "content": "What is an LLM agent?"}]},
        "expected_response": None,
    }
]

eval_dataset = pd.DataFrame(eval_examples)
display(eval_dataset)
import mlflow

with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
    )

# Review the evaluation results in the MLflow UI (see console output),
# or access them in place:
display(eval_results.tables["eval_results"])
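%md To score tool usage specifically, pass custom metrics to `mlflow.evaluate` via `extra_metrics`. A minimal sketch, assuming the `databricks.agents.evals` custom-metric decorator and the `ChatAgentResponse` dict shape for `response`; `called_a_tool` is a hypothetical metric that passes when the agent emitted at least one tool call:

import mlflow
from databricks.agents.evals import metric

@metric
def called_a_tool(request, response):
    # Hypothetical pass/fail check: `response` is assumed to follow the
    # ChatAgentResponse schema, i.e. a dict with a "messages" list
    messages = (response or {}).get("messages", [])
    return "yes" if any(m.get("tool_calls") for m in messages) else "no"

with mlflow.start_run(run_id=logged_agent_info.run_id):
    tool_eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,
        model_type="databricks-agent",
        extra_metrics=[called_a_tool],
    )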
%md
## Pre-deployment agent validation

Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)
%md
## Register the model to Unity Catalog

Before you deploy the agent, you must register the agent to Unity Catalog.

- **TODO** Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)
%md ## Deploy the agent
from databricks import agents

deployment = agents.deploy(
    UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"}
)
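%md Once the endpoint finishes provisioning (which can take several minutes), you can query it directly. A minimal sketch using the MLflow deployments client; it assumes the handle returned by `agents.deploy` above exposes the new endpoint's name via its `endpoint_name` attribute:

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
response = client.predict(
    endpoint=deployment.endpoint_name,
    inputs={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(response)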
%md
## Next steps

After your agent is deployed, you can chat with it in AI playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See docs ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)) for details.