%md
# Mosaic AI Agent Framework: Author a custom schema OpenAI agent

This notebook demonstrates how to author an OpenAI agent that's compatible with Mosaic AI Agent Framework features. In this notebook you learn to:
- Author an OpenAI agent wrapped with `ChatAgent` that accepts custom inputs and returns custom outputs
- Manually test the agent's output
- Log and deploy the agent

To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/author-agent) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/create-chat-model)).

## Prerequisites

- Address all `TODO`s in this notebook.
%pip install -U -qqqq mlflow openai databricks-agents uv
dbutils.library.restartPython()
%md
## Define the agent in code

Define the agent code in a single cell below. This lets you easily write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.
%%writefile agent.py
from typing import Any, Generator, Optional
from uuid import uuid4

import mlflow
from databricks.sdk import WorkspaceClient
from mlflow.entities import SpanType
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

# Automatically trace LLM calls made through the OpenAI SDK
mlflow.openai.autolog()

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"


class CustomChatAgent(ChatAgent):
    def __init__(self):
        self.workspace_client = WorkspaceClient()
        # OpenAI-compatible client that sends requests to Databricks model serving endpoints
        self.client = self.workspace_client.serving_endpoints.get_open_ai_client()
        self.llm_endpoint = LLM_ENDPOINT_NAME

    def prepare_messages_for_llm(self, messages: list[ChatAgentMessage]) -> list[dict[str, Any]]:
        """Filter out ChatAgentMessage fields that are not compatible with LLM message formats"""
        compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
        return [
            {k: v for k, v in m.model_dump_compat(exclude_none=True).items() if k in compatible_keys}
            for m in messages
        ]

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        resp = self.client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(messages),
        )
        # Extra assistant message that accompanies the echoed custom inputs
        custom_output_message = ChatAgentMessage(
            **{"role": "assistant", "content": "Echoing back custom inputs.", "id": str(uuid4())}
        )
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(**resp.choices[0].message.to_dict(), id=resp.id),
                custom_output_message,
            ],
            # Return the caller's custom_inputs unchanged as custom_outputs
            custom_outputs=custom_inputs,
        )

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        for chunk in self.client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(messages),
            stream=True,
        ):
            # Skip chunks that carry no text
            if not chunk.choices or not chunk.choices[0].delta.content:
                continue
            # OpenAI-format chunks from one completion share chunk.id, so these deltas form one message
            yield ChatAgentChunk(
                delta=ChatAgentMessage(
                    **{
                        "role": "assistant",
                        "content": chunk.choices[0].delta.content,
                        "id": chunk.id,
                    }
                )
            )
        # Final chunk: the echo message plus the custom outputs
        yield ChatAgentChunk(
            delta=ChatAgentMessage(
                role="assistant", content="Echoing back custom inputs.", id=str(uuid4())
            ),
            custom_outputs=custom_inputs,
        )


from mlflow.models import set_model

# Tell MLflow which object to load as the model when this file is logged as code
AGENT = CustomChatAgent()
set_model(AGENT)
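%md
`prepare_messages_for_llm` keeps only the keys an OpenAI-style chat API accepts and drops `None` values. The sketch below mimics that filtering on plain dictionaries so you can see which fields survive without calling an endpoint. `filter_for_llm` and the sample messages are hypothetical; the real method works on `ChatAgentMessage` objects via `model_dump_compat(exclude_none=True)`.

```python
from typing import Any

# Keys an OpenAI-style chat message may carry; mirrors compatible_keys in agent.py
LLM_COMPATIBLE_KEYS = {"role", "content", "name", "tool_calls", "tool_call_id"}


def filter_for_llm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Hypothetical plain-dict version of prepare_messages_for_llm.

    Drops None values (like exclude_none=True) and any key the LLM API does
    not accept, such as the ChatAgent message `id`.
    """
    return [
        {k: v for k, v in m.items() if v is not None and k in LLM_COMPATIBLE_KEYS}
        for m in messages
    ]


history = [
    {"role": "user", "content": "What is 5+5?", "id": "msg-1"},
    {"role": "assistant", "content": "5 + 5 = 10", "id": "msg-2", "name": None},
]
llm_messages = filter_for_llm(history)
print(llm_messages)
# [{'role': 'user', 'content': 'What is 5+5?'}, {'role': 'assistant', 'content': '5 + 5 = 10'}]
```

The `id` field matters to Agent Framework for tracking messages, but the LLM endpoint does not need it, which is why it is stripped before each request.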
%md
## Test the agent

Interact with the agent to test its output. Because `predict` and `predict_stream` are decorated with `@mlflow.trace`, you can view the trace for each step the agent takes, and `mlflow.openai.autolog()` automatically traces any LLM calls made through the OpenAI SDK.

Replace this placeholder input with an appropriate domain-specific example for your agent.
dbutils.library.restartPython()
from agent import AGENT

AGENT.predict(
    {
        "messages": [{"role": "user", "content": "What is 5+5?"}],
        "custom_inputs": {"key": "value"},
    }
)
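%md
The custom schema contract of `predict` is simple: the LLM's reply comes back as the first message, a fixed echo message follows, and `custom_inputs` is returned unchanged as `custom_outputs`. As a minimal offline sketch of that contract, the code below uses plain dictionaries and a stub LLM function; `echo_predict` and `fake_llm` are hypothetical names, not part of MLflow or the agent above.

```python
import uuid
from typing import Any, Callable


def echo_predict(
    request: dict[str, Any],
    call_llm: Callable[[list[dict[str, Any]]], str],
) -> dict[str, Any]:
    """Hypothetical offline stand-in for CustomChatAgent.predict using plain dicts."""
    reply = call_llm(request["messages"])
    return {
        "messages": [
            # The LLM's answer
            {"role": "assistant", "content": reply, "id": str(uuid.uuid4())},
            # The fixed message that accompanies the echoed custom inputs
            {"role": "assistant", "content": "Echoing back custom inputs.", "id": str(uuid.uuid4())},
        ],
        # custom_inputs pass straight through to custom_outputs (None if omitted)
        "custom_outputs": request.get("custom_inputs"),
    }


def fake_llm(messages: list[dict[str, Any]]) -> str:
    """Stub LLM so the sketch runs without a serving endpoint."""
    return "5 + 5 = 10"


response = echo_predict(
    {
        "messages": [{"role": "user", "content": "What is 5+5?"}],
        "custom_inputs": {"key": "value"},
    },
    fake_llm,
)
print(response["custom_outputs"])  # {'key': 'value'}
```

Injecting the LLM call as a function keeps the sketch testable offline; in a real agent you would replace the echo with logic that actually reads `custom_inputs`.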
for event in AGENT.predict_stream(
    {
        "messages": [{"role": "user", "content": "What is 5+5?"}],
        "custom_inputs": {"key": "value"},
    }
):
    print(event, "-----------\n")
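%md
`predict_stream` yields many small deltas that reuse the completion's `id`, followed by one final chunk that carries the echo message and `custom_outputs`. The sketch below shows how a client could reassemble such a stream, assuming the endpoint returns the same `id` on every chunk of one completion, as the OpenAI streaming format does. `aggregate_chunks` is a hypothetical helper working on plain dictionaries; it is not MLflow's own aggregation logic.

```python
from typing import Any, Iterable


def aggregate_chunks(chunks: Iterable[dict[str, Any]]) -> dict[str, Any]:
    """Hypothetical helper: merge streamed deltas that share a message id."""
    merged: dict[str, dict[str, Any]] = {}  # insertion order = order messages first appear
    custom_outputs = None
    for chunk in chunks:
        delta = chunk["delta"]
        message = merged.setdefault(
            delta["id"], {"role": delta["role"], "content": "", "id": delta["id"]}
        )
        message["content"] += delta["content"]
        # In agent.py only the final chunk carries custom_outputs
        if chunk.get("custom_outputs") is not None:
            custom_outputs = chunk["custom_outputs"]
    return {"messages": list(merged.values()), "custom_outputs": custom_outputs}


stream = [
    {"delta": {"role": "assistant", "content": "5 + 5 ", "id": "chatcmpl-1"}},
    {"delta": {"role": "assistant", "content": "= 10", "id": "chatcmpl-1"}},
    {
        "delta": {"role": "assistant", "content": "Echoing back custom inputs.", "id": "echo-1"},
        "custom_outputs": {"key": "value"},
    },
]
result = aggregate_chunks(stream)
print([m["content"] for m in result["messages"]])  # ['5 + 5 = 10', 'Echoing back custom inputs.']
```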
%md
## Log the `agent` as an MLflow model

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).
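from importlib.metadata import version

import mlflow
from mlflow.models.resources import DatabricksServingEndpoint

from agent import LLM_ENDPOINT_NAME

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        # Log the agent as code: MLflow stores agent.py and runs it when the model is loaded
        python_model="agent.py",
        # Pin the databricks-connect version installed in this notebook environment
        extra_pip_requirements=[f"databricks-connect=={version('databricks-connect')}"],
        # Declare the serving endpoint the agent calls so Databricks can set up authentication to it at deployment
        resources=[DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)],
    )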
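%md
The logging cell pins `databricks-connect` to the version installed in this notebook, and that version lookup raises an exception if the package is not installed. A minimal sketch, assuming you would rather skip the pin than fail: `pinned_requirement` is a hypothetical helper built on the standard-library `importlib.metadata.version`, the supported replacement for the deprecated `pkg_resources.get_distribution(...).version`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def pinned_requirement(package: str) -> Optional[str]:
    """Return 'package==<installed version>', or None if the package is not installed."""
    try:
        return f"{package}=={version(package)}"
    except PackageNotFoundError:
        return None


# Pin only what is actually installed, instead of failing with an exception
extra_reqs = [req for req in [pinned_requirement("databricks-connect")] if req is not None]
print(extra_reqs)
```

You could pass `extra_reqs` as `extra_pip_requirements` to `mlflow.pyfunc.log_model`; whether silently skipping a missing pin is acceptable depends on your deployment requirements.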
%md
## Pre-deployment agent validation

Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).
# Load the logged model in a fresh uv-managed virtual environment and run a sample request
mlflow.models.predict(
    model_uri=logged_agent_info.model_uri,
    input_data={"messages": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)
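%md
`mlflow.models.predict()` catches dependency and loading problems, but a malformed `input_data` dictionary can also make the check fail for unrelated reasons. The sketch below is a hypothetical, lightweight shape check you could run on a request before sending it. `check_chat_input` is not an MLflow API, and the set of valid roles is an assumption based on the chat message roles used by `ChatAgent`.

```python
from typing import Any

# Assumption: roles accepted by ChatAgent messages
VALID_ROLES = {"system", "user", "assistant", "tool"}


def check_chat_input(input_data: dict[str, Any]) -> list[str]:
    """Hypothetical pre-check; returns a list of problems (empty list means it looks OK)."""
    messages = input_data.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]
    problems = []
    for i, message in enumerate(messages):
        if message.get("role") not in VALID_ROLES:
            problems.append(f"messages[{i}]: unexpected role {message.get('role')!r}")
        # Assistant messages that only request tool calls may omit content
        if "content" not in message and "tool_calls" not in message:
            problems.append(f"messages[{i}]: missing 'content'")
    custom_inputs = input_data.get("custom_inputs")
    if custom_inputs is not None and not isinstance(custom_inputs, dict):
        problems.append("'custom_inputs' must be a dict")
    return problems


print(check_chat_input({"messages": [{"role": "user", "content": "Hello!"}]}))  # []
```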
%md
## Register the model to Unity Catalog

Before you deploy the agent, you must register the agent to Unity Catalog.

- **TODO** Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME)
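%md
If the `TODO` placeholders are left as empty strings, `UC_MODEL_NAME` becomes `..`, which is not a valid three-level `catalog.schema.model_name` name. A minimal sketch of failing fast before calling `mlflow.register_model`: `build_uc_model_name` is a hypothetical helper, and `main`, `default`, and `my_agent` are placeholder values.

```python
def build_uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Hypothetical helper: return 'catalog.schema.model_name', failing fast on unset parts."""
    parts = {"catalog": catalog, "schema": schema, "model_name": model_name}
    missing = [name for name, value in parts.items() if not value.strip()]
    if missing:
        raise ValueError(f"Set these before registering the model: {', '.join(missing)}")
    return f"{catalog}.{schema}.{model_name}"


print(build_uc_model_name("main", "default", "my_agent"))  # main.default.my_agent
```

Raising a `ValueError` that lists exactly which parts are unset gives a clearer message than the registry's rejection of a malformed name.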
%md ## Deploy the agent
from databricks import agents

agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})
%md
## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with subject matter experts (SMEs) in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).