%md # Mosaic AI Agent Framework: Author a custom schema tool-calling LangGraph agent

This notebook demonstrates how to author a LangGraph agent that's compatible with Mosaic AI Agent Framework features. In this notebook you learn to:

- Author a tool-calling LangGraph agent wrapped with `ChatAgent` with custom inputs and outputs
- Manually test the agent's output
- Log and deploy the agent

To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/author-agent) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/create-chat-model)).

## Prerequisites

- Address all `TODO`s in this notebook.
%pip install -U -qqqq mlflow databricks-langchain databricks-agents uv langgraph==0.3.4
dbutils.library.restartPython()
%md ## Define the agent in code

Define the agent code in a single cell below. This lets you easily write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.

#### Agent tools

This agent code adds the built-in Unity Catalog function `system.ai.python_exec` to the agent. The agent code also includes commented-out sample code for adding a vector search index to perform unstructured data retrieval.

For more examples of tools to add to your agent, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).

#### Wrap the LangGraph agent using the `ChatAgent` interface

For compatibility with Databricks AI features, the `LangGraphChatAgent` class implements the `ChatAgent` interface to wrap the LangGraph agent. This example uses the provided convenience APIs [`ChatAgentState`](https://mlflow.org/docs/latest/python_api/mlflow.langchain.html#mlflow.langchain.chat_agent_langgraph.ChatAgentState) and [`ChatAgentToolNode`](https://mlflow.org/docs/latest/python_api/mlflow.langchain.html#mlflow.langchain.chat_agent_langgraph.ChatAgentToolNode).

Databricks recommends using `ChatAgent` because it simplifies authoring multi-turn conversational agents using an open source standard. See MLflow's [ChatAgent documentation](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent).

#### Custom inputs and outputs

The agent handles custom inputs and outputs in the `predict` and `predict_stream` methods and in the `add_custom_outputs` function.

- In the `predict` and `predict_stream` methods of the `LangGraphChatAgent` class:
  - Custom inputs are passed as an optional parameter and included in the request dictionary.
  - Custom outputs are captured from the agent's response and added to the `ChatAgentResponse` object (or, when streaming, to the final `ChatAgentChunk`).
- In the `add_custom_outputs` function:
  - This function is added as the final node in the agent's workflow. It writes `custom_outputs` to the state, merging any existing custom outputs with the custom inputs, before the graph returns the final response.
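The request assembly in `predict` and the merge in `add_custom_outputs` are plain dictionary operations. The following is a minimal pure-Python sketch of both, run outside LangGraph and MLflow. The helper names `build_request` and `merge_custom_outputs` are hypothetical, and `context` is represented as an already-serialized dict rather than a `ChatContext` object.

```python
from typing import Any, Optional


def build_request(
    messages: list[dict[str, Any]],
    context: Optional[dict[str, Any]] = None,
    custom_inputs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Mirror the request dict built in LangGraphChatAgent.predict.

    Optional keys are added only when a non-empty value is supplied.
    """
    return {
        "messages": messages,
        **({"custom_inputs": custom_inputs} if custom_inputs else {}),
        **({"context": context} if context else {}),
    }


def merge_custom_outputs(state: dict[str, Any]) -> dict[str, Any]:
    """Mirror the merge in add_custom_outputs: later keys override earlier ones."""
    return {
        **(state.get("custom_outputs") or {}),
        **(state.get("custom_inputs") or {}),
        "key": "value",
    }


request = build_request(
    [{"role": "user", "content": "Hello!"}],
    custom_inputs={"key": "from_input", "user_id": "u-1"},
)
print(sorted(request))  # ['custom_inputs', 'messages']
print(merge_custom_outputs(request))  # {'key': 'value', 'user_id': 'u-1'}
```

Because the hard-coded `"key": "value"` entry comes last, it overrides a custom input with the same name. Move it to the front of the dict if caller-supplied inputs should win instead.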
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
    VectorSearchRetrieverTool,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Update with your system prompt
system_prompt = ""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

# You can use UDFs in Unity Catalog as agent tools
# Below, we add the `system.ai.python_exec` UDF, which provides
# a Python code interpreter tool to our agent
# You can also add local LangChain Python tools. See https://python.langchain.com/docs/concepts/tools

# TODO: Add additional tools
uc_tool_names = ["system.ai.python_exec"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge

# TODO: Add vector search indexes
# vector_search_tools = [
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# ]
# tools.extend(vector_search_tools)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If the last message requests tool calls, route to the tools node; otherwise, finish
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if agent_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": agent_prompt}] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    def add_custom_outputs(state: ChatAgentState):
        # TODO: Return extra content with the custom_outputs key before returning
        return {
            "messages": [{"role": "assistant", "content": "Adding custom outputs"}],
            "custom_outputs": {
                **(state.get("custom_outputs") or {}),
                **(state.get("custom_inputs") or {}),
                "key": "value",
            },
        }

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.add_node("add_custom_outputs", RunnableLambda(add_custom_outputs))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": "add_custom_outputs",
        },
    )
    workflow.add_edge("tools", "agent")
    workflow.add_edge("add_custom_outputs", END)

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # TODO: Use context and custom_inputs to alter the behavior of the agent
        request = {
            "messages": self._convert_messages_to_dict(messages),
            **({"custom_inputs": custom_inputs} if custom_inputs else {}),
            **({"context": context.model_dump_compat()} if context else {}),
        }

        response_messages = []
        custom_outputs = None
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                if not node_data:
                    continue
                for msg in node_data.get("messages", []):
                    response_messages.append(ChatAgentMessage(**msg))
                # Keep the most recent custom_outputs emitted by any node
                if "custom_outputs" in node_data:
                    custom_outputs = node_data["custom_outputs"]
        return ChatAgentResponse(messages=response_messages, custom_outputs=custom_outputs)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        # TODO: Use context and custom_inputs to alter the behavior of the agent
        request = {
            "messages": self._convert_messages_to_dict(messages),
            **({"custom_inputs": custom_inputs} if custom_inputs else {}),
            **({"context": context.model_dump_compat()} if context else {}),
        }

        last_message = None
        last_custom_outputs = None

        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                if not node_data:
                    continue
                node_messages = node_data.get("messages", [])
                custom_outputs = node_data.get("custom_outputs")
                # Hold back one message so the final chunk can carry custom_outputs
                for message in node_messages:
                    if last_message:
                        yield ChatAgentChunk(delta=last_message)
                    last_message = message
                if custom_outputs:
                    last_custom_outputs = custom_outputs
        if last_message:
            yield ChatAgentChunk(delta=last_message, custom_outputs=last_custom_outputs)


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)
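`predict_stream` emits each message one step late, so that the last message can be yielded together with the accumulated custom outputs. The pattern can be sketched with a plain generator. In this sketch, `fake_updates` is a hypothetical stand-in for the `{node_name: node_data}` events that `agent.stream(..., stream_mode="updates")` produces, and chunks are plain dicts rather than `ChatAgentChunk` objects; both substitutions are assumptions made to keep the example self-contained.

```python
from typing import Any, Iterable, Iterator


def stream_with_lookahead(
    updates: Iterable[dict[str, Any]],
) -> Iterator[dict[str, Any]]:
    """Yield each message one step late so the last one can carry custom_outputs."""
    last_message = None
    last_custom_outputs = None
    for event in updates:
        for node_data in event.values():
            if not node_data:
                continue
            for message in node_data.get("messages", []):
                if last_message:
                    yield {"delta": last_message}
                last_message = message
            if node_data.get("custom_outputs"):
                last_custom_outputs = node_data["custom_outputs"]
    if last_message:
        yield {"delta": last_message, "custom_outputs": last_custom_outputs}


# Hypothetical node updates: agent -> tools -> agent -> add_custom_outputs
fake_updates = [
    {"agent": {"messages": [{"role": "assistant", "content": "", "id": "1"}]}},
    {"tools": {"messages": [{"role": "tool", "content": "10", "id": "2"}]}},
    {"agent": {"messages": [{"role": "assistant", "content": "5+5 is 10", "id": "3"}]}},
    {
        "add_custom_outputs": {
            "messages": [{"role": "assistant", "content": "Adding custom outputs", "id": "4"}],
            "custom_outputs": {"key": "value"},
        }
    },
]

for chunk in stream_with_lookahead(fake_updates):
    print(chunk)
```

Only the fourth chunk includes `custom_outputs`. The cost of this design is one message of latency: each chunk is released only when the next message arrives or the stream ends.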
%md ## Test the agent

Interact with the agent to test its output. Because this notebook called `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.
dbutils.library.restartPython()
from agent import AGENT

AGENT.predict(
    {
        "messages": [{"role": "user", "content": "Hello!"}],
        "custom_inputs": {"key": "value"},
    }
)
for event in AGENT.predict_stream(
    {
        "messages": [{"role": "user", "content": "What is 5+5 in python"}],
        "custom_inputs": {"key": "value"},
    }
):
    print(event, "-----------\n")
%md ## Log the agent as an MLflow model

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources

For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.

- **TODO**: If your Unity Catalog tool queries a [vector search index](docs link) or leverages [external functions](docs link), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)).
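Logging "as code" means MLflow stores the `agent.py` file itself and runs it again when the model is loaded, picking up the object passed to `mlflow.models.set_model()`, instead of serializing the live object. The sketch below mimics only the file-loading half of that idea with the standard library: it writes a tiny stand-in module (hypothetical content, not the real `agent.py`) to a temporary directory, executes it from its path, and reads the module-level `AGENT`. It does not call MLflow, and `load_agent_from_file` is a hypothetical helper.

```python
import importlib.util
import tempfile
from pathlib import Path

# Hypothetical stand-in for agent.py: it defines a module-level AGENT object.
AGENT_SOURCE = '''
class EchoAgent:
    def predict(self, messages):
        return {"messages": [{"role": "assistant", "content": messages[-1]["content"]}]}

AGENT = EchoAgent()
'''


def load_agent_from_file(path: Path):
    """Execute a Python file as a fresh module and return its AGENT attribute."""
    spec = importlib.util.spec_from_file_location("agent_from_code", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.AGENT


with tempfile.TemporaryDirectory() as tmp:
    agent_file = Path(tmp) / "agent.py"
    agent_file.write_text(AGENT_SOURCE)  # roughly what %%writefile does
    loaded = load_agent_from_file(agent_file)
    print(loaded.predict([{"role": "user", "content": "Hello!"}]))
```

Because the file is executed from scratch, everything the agent needs has to be defined or imported inside `agent.py` itself; state that exists only in the notebook session is not available at load time.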
from importlib.metadata import version

import mlflow
from agent import tools, LLM_ENDPOINT_NAME
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# Declare the LLM serving endpoint plus the Databricks resources used by each tool
# so that automatic authentication passthrough works after deployment
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        # Pin databricks-connect to the version installed in this notebook
        extra_pip_requirements=[f"databricks-connect=={version('databricks-connect')}"],
        resources=resources,
    )
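The resource loop above dispatches on the tool's type to decide what to declare. The same pattern in isolation is sketched below with hypothetical stand-in classes (`FakeVectorSearchTool`, `FakeUCTool`) and `(kind, name)` tuples in place of the real `mlflow.models.resources` objects, so it runs without Databricks. The only attributes it relies on, `resources` and `uc_function_name`, are the ones the logging cell reads; the index and endpoint names in the usage lines are made up.

```python
from dataclasses import dataclass


@dataclass
class FakeVectorSearchTool:
    """Stand-in for VectorSearchRetrieverTool, which exposes a `resources` list."""

    index_name: str

    @property
    def resources(self) -> list[tuple[str, str]]:
        return [("vector_search_index", self.index_name)]


@dataclass
class FakeUCTool:
    """Stand-in for UnityCatalogTool, which exposes `uc_function_name`."""

    uc_function_name: str


def collect_resources(endpoint_name: str, tools: list) -> list[tuple[str, str]]:
    """Start with the LLM endpoint, then add whatever each known tool type needs."""
    resources = [("serving_endpoint", endpoint_name)]
    for tool in tools:
        if isinstance(tool, FakeVectorSearchTool):
            resources.extend(tool.resources)
        elif isinstance(tool, FakeUCTool):
            resources.append(("function", tool.uc_function_name))
        # Any other tool type (for example, a local Python tool) adds nothing here.
    return resources


example_tools = [FakeUCTool("system.ai.python_exec"), FakeVectorSearchTool("main.docs.hypothetical_index")]
print(collect_resources("databricks-claude-3-7-sonnet", example_tools))
```

Since an unrecognized tool type silently contributes nothing, it can be worth printing the resulting `resources` list before logging to confirm every dependency is declared.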
%md ## Pre-deployment agent validation

Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)
%md ## Register the model to Unity Catalog

Before you deploy the agent, you must register the agent to Unity Catalog.

- **TODO**: Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
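If the placeholders are left empty, `f"{catalog}.{schema}.{model_name}"` evaluates to `".."`, which is not a usable three-level Unity Catalog name. A small guard can catch unfilled TODOs before `mlflow.register_model` is called. The helper `uc_model_name` below is hypothetical (it is not part of MLflow), and it only checks for empty parts; it does not validate other naming rules.

```python
def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Build a three-level Unity Catalog name, rejecting empty or blank parts."""
    parts = {"catalog": catalog, "schema": schema, "model_name": model_name}
    for label, value in parts.items():
        if not value.strip():
            raise ValueError(f"{label} is empty; fill in the TODO before registering")
    return f"{catalog}.{schema}.{model_name}"


# Hypothetical names, for illustration only
print(uc_model_name("main", "agents", "custom_schema_agent"))  # main.agents.custom_schema_agent
```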
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)
%md ## Deploy the agent
from databricks import agents

agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})
%md ## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with subject-matter experts (SMEs) in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).