Create a vector search retriever

Preview

This feature is in Public Preview.

Learn how to use Mosaic AI Agent Framework to create retrievers. A retriever is a type of agent tool that finds and returns relevant documents using a Vector Search index. Retrievers are a core component of RAG (Retrieval Augmented Generation) applications.

Requirements

  • The MLflow Document class is available only in MLflow version 2.14.0 and above.

  • An existing Vector Search Index.

PyFunc retriever example

The following example uses databricks-vectorsearch to create a basic retriever that performs a Vector Search similarity search with filters. It uses MLflow decorators to enable agent tracing.

The retriever function should return a list of Document objects. Use the metadata field in the Document class to add attributes to the returned document, such as doc_uri and similarity_score.

Use the following code in the agent module or agent notebook.

import mlflow
import json

from mlflow.entities import Document
from typing import List, Dict, Any
from dataclasses import asdict
from databricks.vector_search.client import VectorSearchClient

class VectorSearchRetriever:
    """
    Class using Databricks Vector Search to retrieve relevant documents.
    """
    def __init__(self):
        self.vector_search_client = VectorSearchClient(disable_notice=True)
        # TODO: Replace this with the list of column names to return in the result when querying Vector Search
        self.columns = ["chunk_id", "text_column", "doc_uri"]
        self.vector_search_index = self.vector_search_client.get_index(
            index_name="catalog.schema.chunked_docs_index"
        )
        mlflow.models.set_retriever_schema(
            name="vector_search",
            primary_key="chunk_id",
            text_column="text_column",
            doc_uri="doc_uri"
        )

    @mlflow.trace(span_type="RETRIEVER", name="vector_search")
    def __call__(
        self,
        query: str,
        filters: Dict[Any, Any] = None,
        score_threshold: float = 0.0
    ) -> List[Document]:
        """
        Performs vector search to retrieve relevant chunks.
        Args:
            query: Search query.
            filters: Optional filters to apply to the search. Filters must follow the Databricks Vector Search filter spec.
            score_threshold: Minimum similarity score a result must have to be returned. Defaults to 0.0 so that the comparison never runs against None.

        Returns:
            List of retrieved Documents.
        """

        results = self.vector_search_index.similarity_search(
            query_text=query,
            columns=self.columns,
            filters=filters,
            num_results=5,
            query_type="ann"
        )

        documents = self.convert_vector_search_to_documents(
            results, score_threshold
        )
        # Convert each Document dataclass to a plain dict before returning.
        return [asdict(doc) for doc in documents]

    @mlflow.trace(span_type="PARSER")
    def convert_vector_search_to_documents(
        self, vs_results, score_threshold
    ) -> List[Document]:

        docs = []
        column_names = [column["name"] for column in vs_results.get("manifest", {}).get("columns", [])]
        result_row_count = vs_results.get("result", {}).get("row_count", 0)

        if result_row_count > 0:
            for item in vs_results["result"]["data_array"]:
                metadata = {}
                score = item[-1]

                if score >= score_threshold:
                    metadata["similarity_score"] = score
                    for i, field in enumerate(item[:-1]):
                        metadata[column_names[i]] = field

                    page_content = metadata.pop("text_column", None)

                    if page_content:
                        doc = Document(
                            page_content=page_content,
                            metadata=metadata
                        )
                        docs.append(doc)

        return docs

To run the retriever, run the following Python code. You can optionally include Vector Search filters in the request to filter results.

retriever = VectorSearchRetriever()
query = "What is Databricks?"
filters = {"text_column LIKE": "Databricks"}
results = retriever(query, filters=filters, score_threshold=0.1)

Set retriever schema

To ensure that retrievers are traced properly, call mlflow.models.set_retriever_schema when you define your agent in code. Use set_retriever_schema to map the column names in the returned table to MLflow’s expected fields such as primary_key, text_column, and doc_uri.

# Define the retriever's schema by providing your column names
mlflow.models.set_retriever_schema(
    name="vector_search",
    primary_key="chunk_id",
    text_column="text_column",
    doc_uri="doc_uri",
    # other_columns=["column1", "column2"],
)

Note

The doc_uri column is especially important when evaluating the retriever’s performance. doc_uri is the main identifier for documents returned by the retriever, allowing you to compare them against ground truth evaluation sets. See Evaluation sets.

You can also specify additional columns in your retriever’s schema by providing a list of column names with the other_columns field.

If you have multiple retrievers, you can define multiple schemas by using unique names for each retriever schema.
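
As an illustration of registering more than one schema, the following configuration sketch calls set_retriever_schema once per retriever. The retriever names (docs_retriever, tickets_retriever) and all column names are hypothetical placeholders, not names from a real index. Replace them with your own.

```python
import mlflow

# Schema for a hypothetical retriever over chunked documentation.
mlflow.models.set_retriever_schema(
    name="docs_retriever",
    primary_key="chunk_id",
    text_column="text_column",
    doc_uri="doc_uri",
    other_columns=["title"],  # hypothetical extra column
)

# Schema for a second hypothetical retriever over support tickets.
# The unique name keeps it separate from the schema above.
mlflow.models.set_retriever_schema(
    name="tickets_retriever",
    primary_key="ticket_id",
    text_column="ticket_body",
    doc_uri="ticket_url",
)
```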

Trace the retriever

MLflow tracing adds observability by capturing detailed information about your agent’s execution. It provides a way to record the inputs, outputs, and metadata associated with each intermediate step of a request, enabling you to pinpoint the source of bugs and unexpected behaviors easily.

This example uses the @mlflow.trace decorator to create a trace for the retriever and parser. For other options for setting up trace methods, see MLflow Tracing for agents.

The decorator creates a span that starts when the function is invoked and ends when it returns. MLflow automatically records the function’s input and output and any exceptions raised from it.
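
To make the span lifecycle concrete, here is a toy sketch of what a tracing decorator does, using only the standard library. It is not MLflow's implementation. simple_trace and RECORDED_SPANS are hypothetical names used only for illustration.

```python
import functools
import time

# Hypothetical in-memory span store; MLflow manages its own trace storage.
RECORDED_SPANS = []

def simple_trace(span_type, name=None):
    """Toy stand-in for a tracing decorator (not MLflow's implementation)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The span starts when the function is invoked.
            span = {
                "name": name or func.__name__,
                "span_type": span_type,
                "inputs": {"args": args, "kwargs": kwargs},
                "start": time.time(),
            }
            try:
                span["outputs"] = func(*args, **kwargs)
                return span["outputs"]
            except Exception as exc:
                # Exceptions are recorded, then re-raised unchanged.
                span["exception"] = repr(exc)
                raise
            finally:
                # The span ends when the function returns or raises.
                span["end"] = time.time()
                RECORDED_SPANS.append(span)
        return wrapper
    return decorator

@simple_trace(span_type="RETRIEVER", name="vector_search")
def retrieve(query):
    # Stub retriever that returns one made-up document.
    return [{"page_content": f"stub result for {query}", "metadata": {}}]

retrieve("What is Databricks?")
print(RECORDED_SPANS[0]["name"], RECORDED_SPANS[0]["span_type"])
```

The real @mlflow.trace decorator records the same kinds of information (inputs, outputs, exceptions) and also handles span nesting and storage.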

Note

LangChain, LlamaIndex, and OpenAI library users can use MLflow autologging instead of manually defining traces with the decorator. See Use autologging to add traces to your agents.

...
@mlflow.trace(span_type="RETRIEVER", name="vector_search")
def __call__(self, query: str) -> List[Document]:
  ...

To ensure downstream applications such as Agent Evaluation and the AI Playground render the retriever trace correctly, make sure the decorator meets the following requirements:

  • Use span_type="RETRIEVER" and ensure the function returns a List[Document] object. See Retriever spans.

  • The trace name must match the name passed to set_retriever_schema so that the trace is configured correctly.

Filter Vector Search results

You can limit the search scope to a subset of data using a Vector Search filter.

The filters parameter in VectorSearchRetriever defines the filter conditions using the Databricks Vector Search filter specification.

filters = {"text_column LIKE": "Databricks"}
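
The following sketch shows a few other filter shapes in the dictionary syntax. The column names (id, color, text_column) are hypothetical, and the assumption that multiple keys are combined with AND should be checked against the Vector Search filter specification for your endpoint type.

```python
# Column names below (id, color, text_column) are hypothetical placeholders.

# No operator: exact match; a list matches any of the listed values.
exact_match = {"id": [200, 300]}

# NOT negates the match on the column.
exclude_value = {"color NOT": "red"}

# Comparison operators follow the column name in the key.
range_filter = {"id >=": 100, "id <": 200}

# LIKE matches whitespace-separated tokens in a string column.
token_match = {"text_column LIKE": "Databricks"}

# Assumption: several keys in one dictionary must all hold (logical AND).
combined = {**range_filter, **token_match}
print(combined)
```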

Inside the __call__ method, the filters dictionary is passed directly to the similarity_search function:

results = self.vector_search_index.similarity_search(
    query_text=query,
    columns=self.columns,
    filters=filters,
    num_results=5,
    query_type="ann"
)

After initial filtering, the score_threshold parameter provides additional filtering by setting a minimum similarity score.

if score >= score_threshold:
    metadata["similarity_score"] = score

The final result includes documents that meet the filters and score_threshold conditions.
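
To see the score threshold step in isolation, the following sketch applies it to a hand-written response dictionary instead of a live index. The response shape mirrors what the parser in the example above expects (a manifest of column names and a data_array whose last element is the score). All values are made up, and filter_by_score is a hypothetical helper.

```python
# Hand-written stand-in for a similarity_search response; all values are made up.
mock_response = {
    "manifest": {
        "columns": [
            {"name": "chunk_id"},
            {"name": "text_column"},
            {"name": "doc_uri"},
            {"name": "score"},
        ]
    },
    "result": {
        "row_count": 3,
        "data_array": [
            ["c1", "Databricks is a data platform.", "docs/a.md", 0.82],
            ["c2", "Databricks notebooks overview.", "docs/b.md", 0.35],
            ["c3", "Unrelated Databricks mention.", "docs/c.md", 0.05],
        ],
    },
}

def filter_by_score(response, score_threshold):
    """Keep rows whose trailing score is at least score_threshold."""
    column_names = [c["name"] for c in response["manifest"]["columns"]]
    kept = []
    for row in response["result"]["data_array"]:
        score = row[-1]  # the score is the last element of each row
        if score >= score_threshold:
            record = dict(zip(column_names[:-1], row[:-1]))
            record["similarity_score"] = score
            kept.append(record)
    return kept

kept = filter_by_score(mock_response, score_threshold=0.1)
print([r["chunk_id"] for r in kept])  # ['c1', 'c2']
```

With a threshold of 0.1, the row scored 0.05 is dropped even though it passed the LIKE filter.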

Retriever example applications

See the genai-cookbook GitHub repository for AI agent examples that use retrievers.