
    langgraph-tool-calling-agent

    (Python)

    Mosaic AI Agent Framework: Author and deploy a tool-calling LangGraph agent

    This notebook demonstrates how to author a LangGraph agent that's compatible with Mosaic AI Agent Framework features. In this notebook, you learn how to:

    • Author a tool-calling LangGraph agent wrapped with ChatAgent
    • Manually test the agent's output
    • Evaluate the agent using Mosaic AI Agent Evaluation
    • Log and deploy the agent

    To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation (AWS | Azure).

    Prerequisites

    • Address all TODOs in this notebook.
    %pip install -U -qqqq mlflow-skinny[databricks] databricks-langchain databricks-agents uv langgraph==0.3.4
    dbutils.library.restartPython()

    Define the agent in code

    Define the agent code in a single cell below. This lets you write the agent code to a local Python file with the %%writefile magic command for subsequent logging and deployment.
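    Logging from a file is what MLflow calls "models from code": the logged artifact is the source file, which is re-imported when the model loads. The following stdlib-only sketch mimics that round trip by writing a small module to disk and importing it as a fresh module. The module name agent_sketch and its contents are hypothetical stand-ins for agent.py.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path

# Hypothetical module source; in the notebook, %%writefile agent.py writes the real cell contents.
AGENT_SOURCE = '''
GREETING = "hello from agent.py"

def respond(text):
    return f"{GREETING}: {text}"
'''


def write_and_import(source: str, module_name: str, directory: Path):
    """Write `source` to <directory>/<module_name>.py and import it as a fresh module."""
    path = directory / f"{module_name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing, as importlib's recipe does
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    agent_module = write_and_import(AGENT_SOURCE, "agent_sketch", Path(tmp))
    print(agent_module.respond("hi"))  # the function stays usable after the file is deleted
```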

    Agent tools

    This agent code adds the built-in Unity Catalog function system.ai.python_exec to the agent. The agent code also includes commented-out sample code for adding a vector search index to perform unstructured data retrieval.

    For more examples of tools to add to your agent, see Databricks documentation (AWS | Azure).
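    At runtime, a tool-calling agent receives tool calls from the LLM and executes them by name. The sketch below shows that dispatch step in plain Python, assuming OpenAI-style tool-call dicts (an id plus a function name and JSON-encoded arguments). The add tool and the TOOLS registry are hypothetical; in this notebook, ChatAgentToolNode and UCFunctionToolkit handle this step for you.

```python
import json
from typing import Any, Callable


# Hypothetical local tool standing in for a UC function such as system.ai.python_exec.
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# Hypothetical registry mapping tool names (as the LLM sees them) to callables.
TOOLS: dict[str, Callable[..., Any]] = {"add": add}


def execute_tool_call(tool_call: dict) -> dict:
    """Run one OpenAI-style tool call and wrap the result in a tool-role message."""
    fn = tool_call["function"]
    args = json.loads(fn["arguments"] or "{}")  # arguments arrive as a JSON string
    result = TOOLS[fn["name"]](**args)
    return {"role": "tool", "tool_call_id": tool_call["id"], "content": str(result)}


call = {"id": "call_1", "type": "function", "function": {"name": "add", "arguments": '{"a": 5, "b": 5}'}}
print(execute_tool_call(call))
```

    The tool_call_id in the returned message lets the model match each result to the request that produced it.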

    Wrap the LangGraph agent using the ChatAgent interface

    For compatibility with Databricks AI features, the LangGraphChatAgent class implements the ChatAgent interface to wrap the LangGraph agent. This example uses the provided convenience APIs ChatAgentState and ChatAgentToolNode for ease of use.

    Databricks recommends using ChatAgent as it simplifies authoring multi-turn conversational agents using an open source standard. See MLflow's ChatAgent documentation.
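    LangGraphChatAgent.predict streams the graph with stream_mode="updates", where each event maps the node that just ran to the state update it returned, and flattens those updates into one message list. A minimal sketch of that flattening over plain dicts; SAMPLE_EVENTS is hypothetical data shaped like those updates, not captured LangGraph output.

```python
from typing import Any, Iterable

# Hypothetical events: each maps a node name ("agent" or "tools") to its state update.
SAMPLE_EVENTS = [
    {"agent": {"messages": [{"role": "assistant", "content": "", "tool_calls": [{"id": "call_1"}]}]}},
    {"tools": {"messages": [{"role": "tool", "content": "10", "tool_call_id": "call_1"}]}},
    {"agent": {"messages": [{"role": "assistant", "content": "5+5 is 10."}]}},
]


def collect_messages(events: Iterable[dict[str, dict[str, Any]]]) -> list[dict]:
    """Flatten per-node updates into a single ordered list of messages."""
    messages = []
    for event in events:
        for node_data in event.values():
            # An update without a "messages" key contributes nothing.
            messages.extend(node_data.get("messages", []))
    return messages


print([m["role"] for m in collect_messages(SAMPLE_EVENTS)])
```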

    %%writefile agent.py
    from typing import Any, Generator, Optional, Sequence, Union
    
    import mlflow
    from databricks_langchain import (
        ChatDatabricks,
        UCFunctionToolkit,
        VectorSearchRetrieverTool,
    )
    from langchain_core.language_models import LanguageModelLike
    from langchain_core.runnables import RunnableConfig, RunnableLambda
    from langchain_core.tools import BaseTool
    from langgraph.graph import END, StateGraph
    from langgraph.graph.graph import CompiledGraph
    from langgraph.graph.state import CompiledStateGraph
    from langgraph.prebuilt.tool_node import ToolNode
    from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
    from mlflow.pyfunc import ChatAgent
    from mlflow.types.agent import (
        ChatAgentChunk,
        ChatAgentMessage,
        ChatAgentResponse,
        ChatContext,
    )
    ############################################
    # Define your LLM endpoint and system prompt
    ############################################
    # TODO: Replace with your model serving endpoint
    LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
    llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)
    
    # TODO: Update with your system prompt
    system_prompt = ""
    
    ###############################################################################
    ## Define tools for your agent, enabling it to retrieve data or take actions
    ## beyond text generation
    ## To create and see usage examples of more tools, see
    ## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
    ###############################################################################
    tools = []
    
    # You can use UDFs in Unity Catalog as agent tools.
    # Below, we add the `system.ai.python_exec` UDF, which provides
    # a Python code interpreter tool to the agent.
    # You can also add local LangChain Python tools. See https://python.langchain.com/docs/concepts/tools
    
    # TODO: Add additional tools
    uc_tool_names = ["system.ai.python_exec"]
    uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
    tools.extend(uc_toolkit.tools)
    
    # Use Databricks vector search indexes as tools
    # See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html
    # for details
    
    # TODO: Add vector search indexes
    # vector_search_tools = [
    #     VectorSearchRetrieverTool(
    #         index_name="",
    #         # filters="..."
    #     )
    # ]
    # tools.extend(vector_search_tools)
    
    #####################
    ## Define agent logic
    #####################
    
    
    def create_tool_calling_agent(
        model: LanguageModelLike,
        tools: Union[ToolNode, Sequence[BaseTool]],
        system_prompt: Optional[str] = None,
    ) -> CompiledGraph:
        model = model.bind_tools(tools)
    
        # Define the function that determines which node to go to
        def should_continue(state: ChatAgentState):
            messages = state["messages"]
            last_message = messages[-1]
            # If the last message requests tool calls, continue; otherwise, end.
            if last_message.get("tool_calls"):
                return "continue"
            else:
                return "end"
    
        if system_prompt:
            preprocessor = RunnableLambda(
                lambda state: [{"role": "system", "content": system_prompt}]
                + state["messages"]
            )
        else:
            preprocessor = RunnableLambda(lambda state: state["messages"])
        model_runnable = preprocessor | model
    
        def call_model(
            state: ChatAgentState,
            config: RunnableConfig,
        ):
            response = model_runnable.invoke(state, config)
    
            return {"messages": [response]}
    
        workflow = StateGraph(ChatAgentState)
    
        workflow.add_node("agent", RunnableLambda(call_model))
        workflow.add_node("tools", ChatAgentToolNode(tools))
    
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "continue": "tools",
                "end": END,
            },
        )
        workflow.add_edge("tools", "agent")
    
        return workflow.compile()
    
    
    class LangGraphChatAgent(ChatAgent):
        def __init__(self, agent: CompiledStateGraph):
            self.agent = agent
    
        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> ChatAgentResponse:
            request = {"messages": self._convert_messages_to_dict(messages)}
    
            # Use a separate name so the `messages` argument is not shadowed.
            response_messages = []
            for event in self.agent.stream(request, stream_mode="updates"):
                for node_data in event.values():
                    response_messages.extend(
                        ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                    )
            return ChatAgentResponse(messages=response_messages)
    
        def predict_stream(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> Generator[ChatAgentChunk, None, None]:
            request = {"messages": self._convert_messages_to_dict(messages)}
            for event in self.agent.stream(request, stream_mode="updates"):
                for node_data in event.values():
                    yield from (
                        ChatAgentChunk(delta=msg) for msg in node_data.get("messages", [])
                    )
    
    
    # Create the agent object, and specify it as the agent object to use when
    # loading the agent back for inference via mlflow.models.set_model()
    mlflow.langchain.autolog()
    agent = create_tool_calling_agent(llm, tools, system_prompt)
    AGENT = LangGraphChatAgent(agent)
    mlflow.models.set_model(AGENT)
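    The graph above alternates between the agent node and the tools node until the model stops requesting tools. The following sketch reproduces that control flow without LangGraph, using a fake model and a fake tool runner (both hypothetical) so the loop is easy to trace. The max_steps guard is an addition for this sketch; LangGraph applies a similar cap through its recursion_limit setting.

```python
from typing import Callable


def should_continue(messages: list[dict]) -> str:
    # Same rule as the notebook's router: keep going while the last message requests tools.
    return "continue" if messages[-1].get("tool_calls") else "end"


def run_loop(
    model: Callable[[list[dict]], dict],
    run_tools: Callable[[dict], list[dict]],
    messages: list[dict],
    max_steps: int = 10,
) -> list[dict]:
    messages = list(messages)
    for _ in range(max_steps):
        reply = model(messages)  # "agent" node
        messages.append(reply)
        if should_continue(messages) == "end":
            return messages
        messages.extend(run_tools(reply))  # "tools" node, then back to "agent"
    raise RuntimeError("agent did not finish within max_steps")


# Hypothetical fake model: requests a tool once, then answers from the tool result.
def fake_model(messages: list[dict]) -> dict:
    if messages[-1]["role"] == "tool":
        return {"role": "assistant", "content": f"The answer is {messages[-1]['content']}."}
    return {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1"}]}


# Hypothetical tool runner: returns one tool message per requested call.
def fake_tools(reply: dict) -> list[dict]:
    return [{"role": "tool", "tool_call_id": tc["id"], "content": "10"} for tc in reply["tool_calls"]]


history = run_loop(fake_model, fake_tools, [{"role": "user", "content": "What is 5+5?"}])
print([m["role"] for m in history])
```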

    Test the agent

    Interact with the agent to test its output and tool-calling abilities. Because agent.py calls mlflow.langchain.autolog(), you can view the trace for each step the agent takes.

    Replace this placeholder input with an appropriate domain-specific example for your agent.

    dbutils.library.restartPython()
    from agent import AGENT
    
    print(AGENT.predict({"messages": [{"role": "user", "content": "Hello!"}]}))
    for event in AGENT.predict_stream(
        {"messages": [{"role": "user", "content": "What is 5+5 in python"}]}
    ):
        print(event, "-----------\n")
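    Because predict_stream uses stream_mode="updates", each chunk's delta holds a complete message from one node update rather than a token fragment. To extract the final answer, take the last assistant message that makes no tool calls. A sketch using plain dicts shaped like chunk deltas (an assumption for illustration; the real ChatAgentChunk objects expose delta as an attribute):

```python
from typing import Iterable, Optional


def final_answer(chunks: Iterable[dict]) -> Optional[str]:
    """Return the content of the last assistant delta that makes no tool calls."""
    answer = None
    for chunk in chunks:
        delta = chunk["delta"]
        if delta.get("role") == "assistant" and not delta.get("tool_calls"):
            answer = delta.get("content")
    return answer


# Hypothetical chunks for the "What is 5+5 in python" request.
chunks = [
    {"delta": {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1"}]}},
    {"delta": {"role": "tool", "content": "10", "tool_call_id": "call_1"}},
    {"delta": {"role": "assistant", "content": "5+5 in Python evaluates to 10."}},
]
print(final_answer(chunks))
```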

    Log the agent as an MLflow model

    Log the agent as code from the agent.py file. See MLflow - Models from Code.

    Enable automatic authentication for Databricks resources

    For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

    To enable automatic authentication, specify the dependent Databricks resources when calling mlflow.pyfunc.log_model().

    • TODO: If your Unity Catalog tool queries a vector search index or leverages external functions, you need to include the dependent vector search index and UC connection objects, respectively, as resources. See Databricks documentation (AWS | Azure).
    import mlflow
    from agent import tools, LLM_ENDPOINT_NAME
    from databricks_langchain import VectorSearchRetrieverTool
    from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
    from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
    from importlib.metadata import version
    
    resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
    for tool in tools:
        if isinstance(tool, VectorSearchRetrieverTool):
            resources.extend(tool.resources)
        elif isinstance(tool, UnityCatalogTool):
            resources.append(DatabricksFunction(function_name=tool.uc_function_name))
    
    
    with mlflow.start_run():
        logged_agent_info = mlflow.pyfunc.log_model(
            name="agent",
            python_model="agent.py",
            resources=resources,
            # Pin the versions installed in this notebook so the serving environment matches.
            pip_requirements=[
                f"databricks-connect=={version('databricks-connect')}",
                f"mlflow=={version('mlflow')}",
                f"databricks-langchain=={version('databricks-langchain')}",
                f"langgraph=={version('langgraph')}",
            ],
        )
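    The pip_requirements list pins each dependency to the version installed in the notebook, so the serving environment matches the one you tested. A stdlib sketch of that pinning step with importlib.metadata, which reports packages that are not installed instead of raising. The helper name pinned_requirements is hypothetical, and skipping missing packages is a design choice of this sketch; failing loudly may be safer when logging a production agent.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages: list[str]) -> list[str]:
    """Return "name==installed_version" for each installed package."""
    pins = []
    for name in packages:
        try:
            pins.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            # Skip (and report) packages that are not installed in this environment.
            print(f"warning: {name} is not installed; not pinned")
    return pins


print(pinned_requirements(["pip", "setuptools"]))
```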

    Evaluate the agent with Agent Evaluation

    Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics. See Databricks documentation (AWS | Azure).

    To evaluate your tool calls, add custom metrics. See Databricks documentation (AWS | Azure).

    import mlflow
    from mlflow.genai.scorers import RelevanceToQuery, Safety
    
    eval_dataset = [
        {
            "inputs": {"messages": [{"role": "user", "content": "What is an LLM?"}]},
            "expected_response": None,
        }
    ]
    
    eval_results = mlflow.genai.evaluate(
        data=eval_dataset,
        predict_fn=lambda messages: AGENT.predict({"messages": messages}),
        scorers=[RelevanceToQuery(), Safety()],
    )
    
    # Review the evaluation results in the MLflow UI (see console output)
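    A custom tool-call metric often reduces to inspecting the agent's output messages. The sketch below checks whether any message in the output called a given tool, assuming outputs shaped like a serialized ChatAgentResponse with OpenAI-style tool_calls. The function name called_tool and the tool name used in the example are hypothetical; UC toolkits may expose tool names in a different form than the three-level function name, so check the names that appear in your traces. MLflow 3 provides a scorer decorator in mlflow.genai.scorers for registering a function like this; consult the MLflow documentation for the exact signature it expects.

```python
def called_tool(outputs: dict, tool_name: str) -> bool:
    """Return True if any output message requested a call to `tool_name`."""
    for message in outputs.get("messages", []):
        for tool_call in message.get("tool_calls") or []:
            if tool_call.get("function", {}).get("name") == tool_name:
                return True
    return False


# Hypothetical serialized agent output.
sample_outputs = {
    "messages": [
        {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_1", "type": "function", "function": {"name": "python_exec", "arguments": "{}"}}
        ]},
        {"role": "tool", "content": "10", "tool_call_id": "call_1"},
        {"role": "assistant", "content": "10"},
    ]
}
print(called_tool(sample_outputs, "python_exec"))
```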

    Pre-deployment agent validation

    Before registering and deploying the agent, perform pre-deployment checks using the mlflow.models.predict() API. See Databricks documentation (AWS | Azure).

    mlflow.models.predict(
        model_uri=f"runs:/{logged_agent_info.run_id}/agent",
        input_data={"messages": [{"role": "user", "content": "Hello!"}]},
        env_manager="uv",
    )

    Register the model to Unity Catalog

    Before you deploy the agent, you must register the agent to Unity Catalog.

    • TODO: Update the catalog, schema, and model_name below to register the MLflow model to Unity Catalog.
    mlflow.set_registry_uri("databricks-uc")
    
    # TODO: define the catalog, schema, and model name for your UC model
    catalog = ""
    schema = ""
    model_name = ""
    UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"
    
    # register the model to UC
    uc_registered_model_info = mlflow.register_model(
        model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
    )
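    If the TODO placeholders are left empty, UC_MODEL_NAME becomes "..", and registration cannot succeed. A small, hypothetical helper that validates the three parts before building the three-level name; rejecting dots inside a part is a conservative assumption of this sketch, not a documented Unity Catalog rule.

```python
def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Build catalog.schema.model_name, failing early on empty or dotted parts."""
    parts = {"catalog": catalog, "schema": schema, "model_name": model_name}
    missing = [key for key, value in parts.items() if not value.strip()]
    if missing:
        raise ValueError(f"Set these before registering: {', '.join(missing)}")
    for key, value in parts.items():
        if "." in value:
            raise ValueError(f"{key} must not contain '.': {value!r}")
    return f"{catalog}.{schema}.{model_name}"


print(uc_model_name("main", "default", "my_agent"))
```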

    Deploy the agent

    from databricks import agents
    
    agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})

    Next steps

    After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with subject matter experts (SMEs) in your organization for feedback, or embed it in a production application. See Databricks documentation (AWS | Azure).
