    On-Behalf-Of-User Authentication Agent

    This notebook demonstrates how to build an agent with a Vector Search index that uses on-behalf-of-user authentication. In this notebook you:

    1. Author an agent using on-behalf-of-user clients
    2. Log the Agent with API Scopes
    3. Deploy the Agent to Model Serving

    NOTE: This notebook uses LangChain, but AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or pure Python agents written with the OpenAI SDK.

    Prerequisites

    • Address all TODOs in this notebook.
    %pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
    dbutils.library.restartPython()

    Define the agent in code

    Below we define our agent code in a single cell, using the %%writefile magic command to write it to a local Python file for subsequent logging and deployment.

    For more examples of tools to add to your agent, see docs.

    %%writefile agent.py
    from typing import Any, Generator, Optional
    
    import mlflow
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.credentials_provider import ModelServingUserCredentials
    from databricks_langchain import (
        ChatDatabricks,
        VectorSearchRetrieverTool,
    )
    from langchain_core.language_models import LanguageModelLike
    from langchain_core.runnables import RunnableConfig, RunnableLambda
    from langgraph.graph import END, StateGraph
    from langgraph.graph.graph import CompiledGraph
    from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
    from mlflow.pyfunc import ChatAgent
    from mlflow.types.agent import (
        ChatAgentChunk,
        ChatAgentMessage,
        ChatAgentResponse,
        ChatContext,
    )
    
    mlflow.langchain.autolog()
    
    ############################################
    # Define your LLM endpoint and system prompt
    ############################################
    LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
    llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)
    
    system_prompt = ""
    
    ###############################################################################
    ## Define tools for your agent, enabling it to retrieve data or take actions
    ## beyond text generation
    ## To create and see usage examples of more tools, see
    ## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
    ###############################################################################
    
    
    def create_tools():
        tools = []
        # Use user authenticated client to initialize a vector search retrieval tool
        user_authenticated_client = WorkspaceClient(
            credentials_strategy=ModelServingUserCredentials()
        )
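        # ModelServingUserCredentials resolves to the calling end user's credentials
        # when the agent runs in Model Serving, enabling on-behalf-of-user access
        # to the vector search index below.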
        vector_search_tools = []
        try:
            # TODO fill in fields below
            tool = VectorSearchRetrieverTool(
                index_name="",
                description="",
                tool_name="",
                workspace_client=user_authenticated_client
            )
            vector_search_tools.append(tool)
        except Exception as e:
            print(f"Skipping Vector Search Tool: {e}")
        tools.extend(vector_search_tools)
        return tools
    
    
    #####################
    ## Define agent logic
    #####################
    def create_tool_calling_agent(
        model: LanguageModelLike,
        system_prompt: Optional[str] = None,
    ) -> CompiledGraph:
        tools = create_tools()  # Setup tools
        model = model.bind_tools(tools)
    
        # Define the function that determines which node to go to
        def should_continue(state: ChatAgentState):
            messages = state["messages"]
            last_message = messages[-1]
            # If there are function calls, continue. else, end
            if last_message.get("tool_calls"):
                return "continue"
            else:
                return "end"
    
        if system_prompt:
            preprocessor = RunnableLambda(
                lambda state: [{"role": "system", "content": system_prompt}]
                + state["messages"]
            )
        else:
            preprocessor = RunnableLambda(lambda state: state["messages"])
        model_runnable = preprocessor | model
    
        def call_model(
            state: ChatAgentState,
            config: RunnableConfig,
        ):
            response = model_runnable.invoke(state, config)
    
            return {"messages": [response]}
    
        workflow = StateGraph(ChatAgentState)
    
        workflow.add_node("agent", RunnableLambda(call_model))
        workflow.add_node("tools", ChatAgentToolNode(tools))
    
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "continue": "tools",
                "end": END,
            },
        )
        workflow.add_edge("tools", "agent")
    
        return workflow.compile()
    
    
    class LangGraphChatAgent(ChatAgent):
        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> ChatAgentResponse:
            # Initialize agent in the predict call here
            agent = create_tool_calling_agent(llm, system_prompt)
            request = {"messages": self._convert_messages_to_dict(messages)}
    
            messages = []
            for event in agent.stream(request, stream_mode="updates"):
                for node_data in event.values():
                    messages.extend(
                        ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                    )
            return ChatAgentResponse(messages=messages)
    
        def predict_stream(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> Generator[ChatAgentChunk, None, None]:
            agent = create_tool_calling_agent(llm, system_prompt)
            request = {"messages": self._convert_messages_to_dict(messages)}
            for event in agent.stream(request, stream_mode="updates"):
                for node_data in event.values():
                    yield from (
                        ChatAgentChunk(**{"delta": msg}) for msg in node_data["messages"]
                    )
    
    
    # Create the agent object, and specify it as the agent object to use when
    # loading the agent back for inference via mlflow.models.set_model()
    AGENT = LangGraphChatAgent()
    mlflow.models.set_model(AGENT)

    Test the agent

    Interact with the agent to test its output. Since this notebook called mlflow.langchain.autolog(), you can view the trace for each step the agent takes.

    Replace this placeholder input with an appropriate domain-specific example for your agent.

    dbutils.library.restartPython()
    from agent import AGENT
    
    AGENT.predict({"messages": [{"role": "user", "content": "Hello!"}]})
    for event in AGENT.predict_stream(
        {"messages": [{"role": "user", "content": "What is something cool about databricks"}]}
    ):
        print(event, "-----------\n")

    Log the agent as an MLflow model

    Determine Databricks resources to specify for automatic auth passthrough at deployment time

    • TODO: If your Unity Catalog tool queries a vector search index or leverages external functions, you need to include the dependent vector search index and UC connection objects, respectively, as resources (see the sketch below). See docs for more details.
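
    If tools run under the agent's own (system) credentials and depend on a specific vector search index or UC function, those dependencies can be declared as additional resource objects in the SystemAuthPolicy. In this notebook the vector search index is accessed on behalf of the user instead, so it is covered by the API scopes in the UserAuthPolicy rather than listed as a resource. A minimal, hypothetical sketch with placeholder index and function names:

    # Hypothetical sketch (placeholder names): declaring dependent resources for
    # system-credential auth passthrough.
    from agent import LLM_ENDPOINT_NAME
    from mlflow.models.resources import (
        DatabricksFunction,
        DatabricksServingEndpoint,
        DatabricksVectorSearchIndex,
    )

    example_resources = [
        DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
        DatabricksVectorSearchIndex(index_name="<catalog>.<schema>.<index>"),  # placeholder
        DatabricksFunction(function_name="<catalog>.<schema>.<function>"),  # placeholder
    ]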

    Log the agent as code from the agent.py file. See MLflow - Models from Code.

    # Determine Databricks resources to specify for automatic auth passthrough at deployment time
    import mlflow
    from agent import LLM_ENDPOINT_NAME
    from databricks_langchain import VectorSearchRetrieverTool
    from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
    from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
    from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
    from pkg_resources import get_distribution
    
    # TODO: Manually include underlying resources if needed. 
    resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
    systemAuthPolicy = SystemAuthPolicy(resources=resources)
    
    # TODO: Manually include the required api scopes for this authorization.
    userAuthPolicy = UserAuthPolicy(
        api_scopes=[
            "serving.serving-endpoints",
            "vectorsearch.vector-search-endpoints",
            "vectorsearch.vector-search-indexes",
        ]
    )
    with mlflow.start_run():
        logged_agent_info = mlflow.pyfunc.log_model(
            name="agent",
            python_model="agent.py",
            pip_requirements=[
                f"databricks-connect=={get_distribution('databricks-connect').version}",
                f"mlflow=={get_distribution('mlflow').version}",
                f"databricks-langchain=={get_distribution('databricks-langchain').version}",
                f"langgraph=={get_distribution('langgraph').version}",
            ],
            auth_policy=AuthPolicy(system_auth_policy=systemAuthPolicy, user_auth_policy=userAuthPolicy)
        )
         

    Evaluate the agent with Agent Evaluation

    Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics. See Databricks documentation (AWS | Azure).

    To evaluate your tool calls, add custom metrics. See Databricks documentation (AWS | Azure).
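
    As a rough illustration (not part of the original notebook), a custom check can be written with MLflow's scorer decorator and passed alongside the built-in scorers. The sketch below assumes the @scorer decorator from mlflow.genai.scorers (MLflow 3) accepts a function that takes outputs, and only verifies that the agent returned a non-empty response; replace its body with logic that inspects your tool calls.

    # Hypothetical custom scorer sketch; replace the body with logic that
    # inspects your agent's tool calls.
    from mlflow.genai.scorers import scorer

    @scorer
    def non_empty_response(outputs) -> bool:
        # Minimal placeholder check: the agent produced some output.
        return bool(str(outputs).strip())

    # Pass it alongside the built-in scorers, e.g.
    # scorers=[RelevanceToQuery(), Safety(), non_empty_response]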

    import mlflow
    from mlflow.genai.scorers import RelevanceToQuery, Safety
    
    eval_dataset = [
        {
            "inputs": {"messages": [{"role": "user", "content": "What is an LLM?"}]},
            "expected_response": None,
        }
    ]
    
    eval_results = mlflow.genai.evaluate(
        data=eval_dataset,
        predict_fn=lambda messages: AGENT.predict({"messages": messages}),
        scorers=[RelevanceToQuery(), Safety()],
    )
    
    # Review the evaluation results in the MLflow UI (see console output)

    Perform pre-deployment validation of the agent

    Before registering and deploying the agent, we perform pre-deployment checks via the mlflow.models.predict() API. See documentation for details.

    mlflow.models.predict(
        model_uri=f"runs:/{logged_agent_info.run_id}/agent",
        input_data={"messages": [{"role": "user", "content": "Hello!"}]},
        env_manager="uv",
    )

    Register the model to Unity Catalog

    Update the catalog, schema, and model_name below to register the MLflow model to Unity Catalog.

    mlflow.set_registry_uri("databricks-uc")
    
    # TODO: define the catalog, schema, and model name for your UC model
    catalog = ""
    schema = ""
    model_name = ""
    UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"
    
    # register the model to UC
    uc_registered_model_info = mlflow.register_model(
        model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
    )

    Deploy the agent

    from databricks import agents
    agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})

    Next steps

    After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See docs for details.
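
    For example, a production application could query the deployed endpoint through the MLflow deployments client. This is a hypothetical sketch; the endpoint name below is a placeholder for the one printed by agents.deploy().

    # Hypothetical sketch: querying the deployed agent endpoint from an application.
    from mlflow.deployments import get_deploy_client

    client = get_deploy_client("databricks")
    response = client.predict(
        endpoint="<your-agent-endpoint-name>",  # placeholder; use the name from agents.deploy()
        inputs={"messages": [{"role": "user", "content": "Hello!"}]},
    )
    print(response)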
