
on-behalf-of-user-authentication

(Python)

On-Behalf-Of-User Authentication Agent

This notebook demonstrates how to build an agent with a Vector Search index that uses on-behalf-of-user authentication. In this notebook you:

  1. Author an agent using on-behalf-of-user clients
  2. Log the agent with API scopes
  3. Deploy the agent to Model Serving

NOTE: This notebook uses LangChain, but the AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or pure Python agents written with the OpenAI SDK.
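The core on-behalf-of-user pattern, shown here as a minimal sketch (the same calls appear in the agent code below): a WorkspaceClient constructed with the ModelServingUserCredentials strategy authenticates as the end user making the request when running inside Model Serving.

from databricks.sdk import WorkspaceClient
from databricks.sdk.credentials_provider import ModelServingUserCredentials

# Inside a Model Serving endpoint, this client acts with the identity of the
# end user making the request. Outside of serving, tool creation with this
# client may fail, which the agent code below handles with a try/except.
user_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials())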

Prerequisites

  • Address all TODOs in this notebook.
%pip install -U -qqqq mlflow langchain langgraph==0.3.4 databricks-langchain pydantic databricks-agents unitycatalog-langchain[databricks] uv
dbutils.library.restartPython()

Define the agent in code

Below, we define the agent code in a single cell so the %%writefile magic command can write it to a local Python file for subsequent logging and deployment.

For more examples of tools to add to your agent, see docs.
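For instance, here is a minimal sketch of wrapping a Unity Catalog function as an additional tool with the UCFunctionToolkit imported in the agent code; the function name is a hypothetical placeholder.

from databricks_langchain.uc_ai import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)

client = DatabricksFunctionClient()
set_uc_function_client(client)

# "catalog.schema.my_function" is a placeholder; replace it with a real UC function
uc_tools = UCFunctionToolkit(function_names=["catalog.schema.my_function"]).tools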

%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from databricks_langchain.uc_ai import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from databricks.sdk import WorkspaceClient
from databricks.sdk.credentials_provider import ModelServingUserCredentials

# Enable MLflow tracing of each step the agent takes
mlflow.langchain.autolog()

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

system_prompt = ""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################

def create_tools():
    tools = []
    # Use user authenticated client to initialize a vector search retrieval tool
    user_authenticated_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials())
    vector_search_tools = []
    try:
        tool = VectorSearchRetrieverTool(
            index_name="<FILL INDEX NAME>",  # TODO
            description="<FILL DESCRIPTION>",  # TODO
            tool_name="<FILL TOOL NAME>",  # TODO
            workspace_client=user_authenticated_client,
        )
        vector_search_tools.append(tool)
    except Exception as e:
        print(f"Skipping Vector Search Tool: {e}")
    tools.extend(vector_search_tools)
    return tools

#####################
## Define agent logic
#####################
def create_tool_calling_agent(
    model: LanguageModelLike,
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    tools = create_tools()  # Set up tools with the user-authenticated client
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If the last message contains tool calls, continue; otherwise, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # Initialize the agent inside predict so tools are created with the
        # calling user's on-behalf-of-user credentials on each request
        agent = create_tool_calling_agent(llm, system_prompt)
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        agent = create_tool_calling_agent(llm, system_prompt)
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg}) for msg in node_data["messages"]
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
AGENT = LangGraphChatAgent()
mlflow.models.set_model(AGENT)

Test the agent

Interact with the agent to test its output. Since this notebook calls mlflow.langchain.autolog(), you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.

dbutils.library.restartPython()
from agent import AGENT

AGENT.predict({"messages": [{"role": "user", "content": "Hello!"}]})
for event in AGENT.predict_stream(
    {"messages": [{"role": "user", "content": "What is 5+5 in python"}]}
):
    print(event, "-----------\n")

Log the agent as an MLflow model

Determine Databricks resources to specify for automatic auth passthrough at deployment time

  • TODO: If your Unity Catalog tool queries a vector search index or leverages external functions, you need to include the dependent vector search index and UC connection objects, respectively, as resources, as shown in the sketch below. See docs for more details.

Log the agent as code from the agent.py file. See MLflow - Models from Code.
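For example, a minimal sketch of declaring a dependent vector search index as a system resource; the index name is a hypothetical placeholder.

from mlflow.models.resources import (
    DatabricksServingEndpoint,
    DatabricksVectorSearchIndex,
)

resources = [
    DatabricksServingEndpoint(endpoint_name="databricks-meta-llama-3-3-70b-instruct"),
    # Hypothetical index name; use the index your retriever tool queries
    DatabricksVectorSearchIndex(index_name="catalog.schema.my_index"),
]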

# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy

# TODO: Manually include underlying resources if needed.
resources = [DatabricksServingEndpoint(endpoint_name="databricks-meta-llama-3-3-70b-instruct")]
system_auth_policy = SystemAuthPolicy(resources=resources)

# TODO: Manually include the API scopes required for this agent's on-behalf-of-user calls.
user_auth_policy = UserAuthPolicy(
    api_scopes=[
        "serving.serving-endpoints",
        "vectorsearch.vector-search-endpoints",
        "vectorsearch.vector-search-indexes",
    ]
)

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "Hello!"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        input_example=input_example,
        pip_requirements=[
            "mlflow",
            "langchain",
            "langgraph",
            "databricks-langchain",
            "unitycatalog-langchain[databricks]",
            "pydantic",
            "databricks-sdk"
        ],
        auth_policy=AuthPolicy(system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy),
    )
     

Evaluate the agent with Agent Evaluation

You can edit the requests or expected responses in your evaluation dataset and rerun evaluation as you iterate on your agent, using MLflow to track the computed quality metrics.

To evaluate your tool calls, try adding custom metrics, as in the sketch below.
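For example, a minimal sketch of a pass/fail custom metric, assuming the databricks-agents custom-metrics API (the @metric decorator and the extra_metrics argument to mlflow.evaluate); the metric itself is hypothetical.

from databricks.agents.evals import metric

@metric
def used_a_tool(tool_calls):
    # Hypothetical check: pass if the agent invoked at least one tool
    return tool_calls is not None and len(tool_calls) > 0

# Then pass it to evaluation alongside the built-in judges, e.g.:
# mlflow.evaluate(..., model_type="databricks-agent", extra_metrics=[used_a_tool])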

import pandas as pd

eval_examples = [
    {
        "request": {
            "messages": [
                {
                    "role": "user",
                    "content": "What is an LLM agent?"
                }
            ]
        },
        "expected_response": None
    }
]

eval_dataset = pd.DataFrame(eval_examples)
display(eval_dataset)

import mlflow

with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
    )

# Review the evaluation results in the MLflow UI (see console output), or access them in place:
display(eval_results.tables['eval_results'])

Perform pre-deployment validation of the agent

Before registering and deploying the agent, we perform pre-deployment checks via the mlflow.models.predict() API. See documentation for details.

mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": "Hello!"}]},
    env_manager="uv",
)

Register the model to Unity Catalog

Update the catalog, schema, and model_name below to register the MLflow model to Unity Catalog.

mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

Deploy the agent

from databricks import agents
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "playground"})

Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application, as in the sketch below. See docs for details.
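For example, a minimal sketch of querying the deployed agent from an application using the MLflow deployments client; the endpoint name is a hypothetical placeholder for the one agents.deploy() prints above.

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
response = client.predict(
    endpoint="<your-agent-endpoint-name>",  # placeholder; use the deployed endpoint name
    inputs={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(response)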
