    Enable Databricks Mosaic AI Gateway features

    This notebook shows how to enable and use Databricks Mosaic AI Gateway features to manage and govern models from providers, such as OpenAI and Anthropic.

    In this notebook, you use the Model Serving and AI Gateway API to accomplish the following tasks:

    • Create and configure an endpoint for OpenAI GPT-4o-Mini.
    • Enable AI Gateway features including usage tracking, inference tables, guardrails, and rate limits.
    • Set up personally identifiable information (PII) detection for model requests and responses.
    • Implement rate limits for model serving endpoints.
    • Configure multiple models for A/B testing.
    • Enable fallbacks for failed requests.

    If you prefer a low-code experience, you can create an external models endpoint and configure AI Gateway features using the Serving UI (AWS | Azure | GCP).

    %pip install --quiet openai
    %restart_python
    from databricks.sdk import WorkspaceClient
    
    w = WorkspaceClient()
    
    DATABRICKS_HOST = w.config.host
    
    # name of model serving endpoint
    ENDPOINT_NAME = "<endpoint_name>"
    
    # catalog and schema for inference tables
    CATALOG_NAME = "<catalog_name>"
    SCHEMA_NAME = "<schema_name>"
    
    # openai API key in Databricks Secrets
    SECRETS_SCOPE = "<secrets_scope>"
    SECRETS_KEY = "OPENAI_API_KEY"
    
    # if you need to add an OpenAI API key, you can do so with:
    
    # w.secrets.put_secret(scope=SECRETS_SCOPE, key=SECRETS_KEY, string_value='<key_value>')
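
    If the secret scope doesn't exist yet, you can create it and store the OpenAI API key with the Databricks SDK before continuing. The following is a minimal sketch; the scope name, key name, and key value are placeholders that you supply:

    # create the secret scope if it doesn't already exist
    existing_scopes = [scope.name for scope in w.secrets.list_scopes()]
    if SECRETS_SCOPE not in existing_scopes:
        w.secrets.create_scope(scope=SECRETS_SCOPE)

    # store the OpenAI API key; replace <key_value> with your actual key
    w.secrets.put_secret(scope=SECRETS_SCOPE, key=SECRETS_KEY, string_value="<key_value>")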

    Create a model serving endpoint for OpenAI GPT-4o-Mini

    The following creates a model serving endpoint for GPT-4o Mini without AI Gateway enabled. First, you define a helper function for creating and updating the endpoint:

    import requests
    import json
    import time
    from typing import Optional
    
    
    def configure_endpoint(
        name: str,
        databricks_token: str,
        config: dict,
        host: str,
        endpoint_path: Optional[str] = None,
    ):
        base_url = f"{host}/api/2.0/serving-endpoints"
    
        if endpoint_path:
            # Update operation
            api_url = f"{base_url}/{name}/{endpoint_path}"
            method = requests.put
            operation = "Updating"
        else:
            # Create operation
            api_url = base_url
            method = requests.post
            operation = "Creating"
    
        headers = {
            "Authorization": f"Bearer {databricks_token}",
            "Content-Type": "application/json",
        }
    
        print(f"{operation} endpoint...")
        response = method(api_url, headers=headers, json=config)
    
        if response.status_code == 200:
            return response.json()
        else:
            print(
                f"Failed to {operation.lower()} endpoint. Status code: {response.status_code}"
            )
            return response.text

    Next, write a simple configuration to set up the endpoint. See POST /api/2.0/serving-endpoints for API details.

    create_endpoint_request_data = {
        "name": ENDPOINT_NAME,
        "config": {
            "served_entities": [
                {
                    "name": "gpt-4o-mini",
                    "external_model": {
                        "name": "gpt-4o-mini",
                        "provider": "openai",
                        "task": "llm/v1/chat",
                        "openai_config": {
                            "openai_api_key": f"{{{{secrets/{SECRETS_SCOPE}/{SECRETS_KEY}}}}}",
                        },
                    },
                }
            ],
        },
    }
    import time
    
    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value
    
    configure_endpoint(
        ENDPOINT_NAME, tmp_token, create_endpoint_request_data, DATABRICKS_HOST
    )

    One of the immediate benefits of serving OpenAI models (or models from other providers) through Databricks is that you can immediately query the model using any of the following methods:

    • Databricks Python SDK
    • OpenAI Python client
    • REST API calls
    • MLflow Deployments SDK
    • Databricks SQL ai_query function

    See the Query foundation models and external models article (AWS | Azure | GCP).

    For example, you can use ai_query to query the model with Databricks SQL.

    %sql
    SELECT
      ai_query(
        "<endpoint name>",
        "What is a mixture of experts model?"
      )
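
    You can also query the endpoint from Python. The following is a minimal sketch that uses the OpenAI Python client authenticated with a short-lived Databricks token, the same pattern used later in this notebook:

    from openai import OpenAI

    # the OpenAI client works against the Databricks serving endpoint because the
    # endpoint exposes an OpenAI-compatible chat API (task llm/v1/chat)
    client = OpenAI(
        api_key=w.tokens.create(
            comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
        ).token_value,
        base_url=f"{w.config.host}/serving-endpoints",
    )

    response = client.chat.completions.create(
        model=ENDPOINT_NAME,
        messages=[{"role": "user", "content": "What is a mixture of experts model?"}],
        max_tokens=256,
    )
    print(response.choices[0].message.content)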

    Add an AI Gateway configuration

    After you set up a model serving endpoint, you can query the OpenAI model using any of the methods listed above.

    You can further enrich the model serving endpoint by enabling the Databricks Mosaic AI Gateway, which offers a variety of features for monitoring and managing your endpoint. These features include inference tables, guardrails, and rate limits, among other things.

    To start, the following is a simple configuration that enables usage tracking and inference tables for monitoring endpoint usage. Understanding how the endpoint is being used, and how often, helps you determine which usage limits and guardrails are beneficial for your use case.

    gateway_request_data = {
        "usage_tracking_config": {"enabled": True},
        "inference_table_config": {
            "enabled": True,
            "catalog_name": CATALOG_NAME,
            "schema_name": SCHEMA_NAME,
        },
    }
    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value
    
    configure_endpoint(
        ENDPOINT_NAME, tmp_token, gateway_request_data, DATABRICKS_HOST, "ai-gateway"
    )
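
    To confirm that the AI Gateway settings were applied, you can fetch the endpoint definition with a GET request. The following is a minimal sketch; once the update succeeds, the returned document should include the gateway configuration:

    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value

    # GET /api/2.0/serving-endpoints/{name} returns the endpoint definition,
    # including the AI Gateway configuration if one is set
    response = requests.get(
        f"{DATABRICKS_HOST}/api/2.0/serving-endpoints/{ENDPOINT_NAME}",
        headers={"Authorization": f"Bearer {tmp_token}"},
    )
    print(json.dumps(response.json().get("ai_gateway", {}), indent=2))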

    Query the inference table

    The following displays the inference table that was created when you enabled inference tables in the AI Gateway configuration. Note: For example purposes, a number of queries were run against this endpoint in the AI Playground after the update above was applied, but before the table was queried here.

    spark.sql(
        f"""select request_time, status_code, request, response
            from {CATALOG_NAME}.{SCHEMA_NAME}.`{ENDPOINT_NAME}_payload`
            where status_code=200
            limit 10"""
    ).display()

    You can extract details such as the request messages, response messages, and token counts using SQL:

    query = f"""SELECT
      request_time,
      from_json(
        request,
        'array<struct<messages:array<struct<role:string, content:string>>>>'
      ).messages [0].content AS request_messages,
      from_json(
        response,
        'struct<choices:array<struct<message:struct<role:string, content:string>>>>'
      ).choices [0].message.content AS response_messages,
      from_json(
        response,
        'struct<choices:array<struct<message:struct<role:string, content:string>>>, usage:struct<prompt_tokens:int, completion_tokens:int, total_tokens:int>>'
      ).usage.prompt_tokens AS prompt_tokens,
      from_json(
        response,
        'struct<choices:array<struct<message:struct<role:string, content:string>>>, usage:struct<prompt_tokens:int, completion_tokens:int, total_tokens:int>>'
      ).usage.completion_tokens AS completion_tokens,
      from_json(
        response,
        'struct<choices:array<struct<message:struct<role:string, content:string>>>, usage:struct<prompt_tokens:int, completion_tokens:int, total_tokens:int>>'
      ).usage.total_tokens AS total_tokens
    FROM
      {CATALOG_NAME}.{SCHEMA_NAME}.`{ENDPOINT_NAME}_payload`
    WHERE
      status_code = 200
    LIMIT
      10;"""
    
    spark.sql(query).display()
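
    Building on the same parsing logic, you can aggregate token usage over time, for example to estimate daily consumption. The following is a minimal sketch:

    usage_query = f"""SELECT
      DATE_TRUNC('day', request_time) AS day,
      COUNT(*) AS requests,
      SUM(
        from_json(
          response,
          'struct<usage:struct<prompt_tokens:int, completion_tokens:int, total_tokens:int>>'
        ).usage.total_tokens
      ) AS total_tokens
    FROM
      {CATALOG_NAME}.{SCHEMA_NAME}.`{ENDPOINT_NAME}_payload`
    WHERE
      status_code = 200
    GROUP BY
      DATE_TRUNC('day', request_time)
    ORDER BY
      day DESC"""

    spark.sql(usage_query).display()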

    Set up AI Guardrails

    Set up PII detection

    With AI Guardrails, you can make sure the endpoint doesn't accept requests containing personally identifiable information (PII) or respond with messages that contain any PII.

    The following updates the guardrails configuration to block PII in both requests and responses:

    gateway_request_data.update(
        {
            "guardrails": {
                "input": {
                    "pii": {"behavior": "BLOCK"},
    
                },
                "output": {
                    "pii": {"behavior": "BLOCK"},
                },
            }
        }
    )
    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value
    
    configure_endpoint(
        ENDPOINT_NAME, tmp_token, gateway_request_data, DATABRICKS_HOST, "ai-gateway"
    )

    The following prompt asks the model to return PII. The request is blocked, and the notebook prints the message "Error: PII (Personally Identifiable Information) detected. Please try again."

    fictional_data = """
    Samantha Lee, slee@fictional-corp.com, (555) 123-4567, Senior Marketing Manager
    Raj Patel, rpatel@imaginary-tech.net, (555) 987-6543, Software Engineer II
    Elena Rodriguez, erodriguez@pretend-company.org, (555) 246-8135, Director of Operations
    """
    
    prompt = f"""
    You are an AI assistant for a company's HR department. Using the employee data provided below, answer the following question:
    
    What is Raj Patel's phone number and email address?
    
    Employee data:
    {fictional_data}
    """
    
    
    from openai import OpenAI

    client = OpenAI(
        api_key=w.tokens.create(
            comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
        ).token_value,
        base_url=f"{w.config.host}/serving-endpoints",
    )
    
    try:
        response = client.chat.completions.create(
            model=ENDPOINT_NAME,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=256,
        )
        print(response)
    except Exception as e:
        if "pii_detection': True" in str(e):
            print(
                "Error: PII (Personally Identifiable Information) detected. Please try again."
            )
        else:
            print(f"An error occurred: {e}")
    Error: PII (Personally Identifiable Information) detected. Please try again.

    Add rate limits

    Suppose that, while investigating the inference tables further, you see steep spikes in usage that suggest a higher-than-expected volume of queries. Extremely high usage could become costly if it is not monitored and limited.

    query = f"""SELECT
      DATE_TRUNC('minute', request_time) AS minute,
      COUNT(DISTINCT databricks_request_id) AS queries_per_minute
    FROM
      {CATALOG_NAME}.{SCHEMA_NAME}.`{ENDPOINT_NAME}_payload`
    WHERE
      request_time >= CURRENT_TIMESTAMP - INTERVAL 20 HOURS
    GROUP BY
      DATE_TRUNC('minute', request_time)
    ORDER BY
      minute DESC;
    """
    
    spark.sql(query).display()

    You can set a rate limit to prevent excessive queries. In this case, you can set the limit on the endpoint, but it is also possible to set per-user limits.

    gateway_request_data.update(
        {
            "rate_limits": [{"calls": 10, "key": "endpoint", "renewal_period": "minute"}],
        }
    )
    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value
    
    configure_endpoint(
        ENDPOINT_NAME, tmp_token, gateway_request_data, DATABRICKS_HOST, "ai-gateway"
    )
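
    For reference, a per-user limit uses the same shape with a different key. The following is a minimal sketch of such a configuration; it is not applied here, so the endpoint-level example below behaves as shown:

    per_user_rate_limits = {
        "rate_limits": [
            # endpoint-wide limit shared by all callers
            {"calls": 10, "key": "endpoint", "renewal_period": "minute"},
            # per-user limit: each individual user is capped separately
            {"calls": 5, "key": "user", "renewal_period": "minute"},
        ],
    }
    # to apply it, merge this into gateway_request_data and call
    # configure_endpoint(..., "ai-gateway") as above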

    The following shows an example of what the output error looks like when the rate limit is exceeded.

    client = OpenAI(
        api_key=w.tokens.create(
            comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
        ).token_value,
        base_url=f"{w.config.host}/serving-endpoints",
    )
    
    start_time = time.time()
    for i in range(1, 12):
        client.chat.completions.create(
            model=ENDPOINT_NAME,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"This is request {i}"},
            ],
            max_tokens=10,
        )
        print(f"Request {i} sent")
    print(f"Total time: {time.time() - start_time:.2f} seconds")
    Request 1 sent
    Request 2 sent
    Request 3 sent
    Request 4 sent
    Request 5 sent
    Request 6 sent
    Request 7 sent
    Request 8 sent
    Request 9 sent
    Request 10 sent
    RateLimitError: Error code: 429 - {'error_code': 'REQUEST_LIMIT_EXCEEDED', 'message': 'REQUEST_LIMIT_EXCEEDED: User defined rate limit(s) exceeded for endpoint: dr-gateway-demo.'}
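
    If clients are expected to hit the limit occasionally, they can catch the 429 and retry after the renewal period. The following is a minimal sketch; chat_with_retry is a hypothetical helper written for illustration, and it reuses the client object created above:

    from openai import RateLimitError


    def chat_with_retry(client, messages, max_retries=3, wait_seconds=60):
        """Retry a chat completion after a rate-limit error, waiting out the renewal period."""
        for attempt in range(max_retries):
            try:
                return client.chat.completions.create(
                    model=ENDPOINT_NAME,
                    messages=messages,
                    max_tokens=256,
                )
            except RateLimitError:
                if attempt == max_retries - 1:
                    raise
                print(f"Rate limit exceeded; retrying in {wait_seconds} seconds...")
                time.sleep(wait_seconds)


    response = chat_with_retry(
        client, [{"role": "user", "content": "What is a mixture of experts model?"}]
    )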

    Add another model

    At some point, you might want to A/B test models, whether from the same provider or from different providers. You can add another OpenAI model to the configuration, like in the following example:

    new_config = {
        "served_entities": [
            {
                "name": "gpt-4o-mini",
                "external_model": {
                    "name": "gpt-4o-mini",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": f"{{{{secrets/{SECRETS_SCOPE}/{SECRETS_KEY}}}}}",
                    },
                },
            },
            {
                "name": "gpt-4o",
                "external_model": {
                    "name": "gpt-4o",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": f"{{{{secrets/{SECRETS_SCOPE}/{SECRETS_KEY}}}}}",
                    },
                },
            },
        ],
        "traffic_config": {
            "routes": [
                {"served_model_name": "gpt-4o-mini", "traffic_percentage": 50},
                {"served_model_name": "gpt-4o", "traffic_percentage": 50},
            ]
        },
    }
    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value
    
    configure_endpoint(ENDPOINT_NAME, tmp_token, new_config, DATABRICKS_HOST, "config")

    Now, traffic is split between the two models according to the traffic_config proportions (50/50 in this example). This lets you use the inference tables to evaluate the quality of each model and make an informed decision about switching from one model to the other.
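
    For example, you can break down inference-table activity by served model. The following is a minimal sketch; it assumes the payload table includes a served_entity_id column that identifies which served entity handled each request:

    comparison_query = f"""SELECT
      served_entity_id,
      COUNT(*) AS requests,
      SUM(
        from_json(
          response,
          'struct<usage:struct<prompt_tokens:int, completion_tokens:int, total_tokens:int>>'
        ).usage.total_tokens
      ) AS total_tokens
    FROM
      {CATALOG_NAME}.{SCHEMA_NAME}.`{ENDPOINT_NAME}_payload`
    WHERE
      status_code = 200
    GROUP BY
      served_entity_id"""

    spark.sql(comparison_query).display()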

    Enable fallback models for requests

    For endpoints that serve external models, you can configure fallbacks.

    Enabling fallbacks ensures that if a request to one served entity fails with a 429 or 5XX error, it automatically fails over to the next entity in the listed order, cycling back to the top of the list if necessary. A maximum of two fallbacks is allowed. Any external model assigned 0% traffic functions exclusively as a fallback model. The first successful attempt, or the last failed attempt, is recorded in both the usage tracking system table and the inference table.

    In the following example:

    • The traffic_config field specifies that 50 percent of traffic goes to external_model_1 and the other 50 percent goes to external_model_2.
    • In the ai_gateway section, the fallback_config field specifies that fallbacks are enabled.
    • If a request sent to external_model_1 fails, it is redirected to the next model listed in the traffic configuration, in this case, external_model_2.

    # endpoint_name, external_model_1, and external_model_2 are placeholders
    # that you define elsewhere
    endpoint_config = {
        "name": endpoint_name,
        "config": {
            # define your external models as served entities
            "served_entities": [
                external_model_1,
                external_model_2,
            ],
            "traffic_config": {
                "routes": [
                    # 50% of traffic goes to the first external model
                    {"served_model_name": "external_model_1", "traffic_percentage": 50},
                    # 50% of traffic goes to the second external model
                    {"served_model_name": "external_model_2", "traffic_percentage": 50},
                ]
            },
        },
        # enable fallbacks (fallbacks occur in the order of the served entities)
        "ai_gateway": {"fallback_config": {"enabled": True}},
    }
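
    Once endpoint_name, external_model_1, and external_model_2 are defined (for example, as external-model entities like the gpt-4o-mini entity earlier in this notebook), you can create the endpoint by reusing the configure_endpoint helper. The following is a minimal sketch:

    tmp_token = w.tokens.create(
        comment=f"sdk-{time.time_ns()}", lifetime_seconds=120
    ).token_value

    # create the endpoint by POSTing the full body (name, config, and ai_gateway)
    configure_endpoint(endpoint_name, tmp_token, endpoint_config, DATABRICKS_HOST)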