Adicionar traços a seus agentes
Importante
Esse recurso está em Prévia Pública.
Este artigo mostra como adicionar rastreamentos aos seus agentes usando o Fluent e o MLflowClient APIs disponibilizados com o MLflow Tracing.
Observação
Para obter uma referência detalhada da API e exemplos de código para o MLflow Tracing, consulte a documentação do MLflow.
Use o registro automático para adicionar rastreamentos aos seus agentes
Se estiver usando uma biblioteca GenAI que tenha suporte para rastreamento (como LangChain, LlamaIndex ou OpenAI), o senhor pode ativar o autologging do MLflow para a integração da biblioteca para ativar o rastreamento. Por exemplo, use mlflow.langchain.autolog()
para adicionar rastreamentos automaticamente ao seu agente baseado em Langchain.
Observação
A partir de Databricks Runtime 15.4 LTS ML, o rastreamento MLflow é ativado por default no Notebook. Para desativar o rastreamento, por exemplo, com LangChain, o senhor pode executar mlflow.langchain.autolog(log_traces=False)
no Notebook.
mlflow.langchain.autolog()
MLflow oferece suporte a biblioteca adicional para autologação de rastreamento. Consulte a documentação de rastreamento do siteMLflow para obter uma lista completa de bibliotecas integradas.
Use as APIs do Fluent para adicionar manualmente traços ao seu agente
A seguir, um exemplo rápido que usa as APIs do Fluent: mlflow.trace
e mlflow.start_span
para adicionar traços ao quickstart-agent
. Isso é recomendado para modelos PyFunc.
import mlflow
from mlflow.deployments import get_deploy_client
class QAChain(mlflow.pyfunc.PythonModel):
def __init__(self):
self.client = get_deploy_client("databricks")
@mlflow.trace(name="quickstart-agent")
def predict(self, model_input, system_prompt, params):
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": model_input[0]["query"]
}
]
traced_predict = mlflow.trace(self.client.predict)
output = traced_predict(
endpoint=params["model_name"],
inputs={
"temperature": params["temperature"],
"max_tokens": params["max_tokens"],
"messages": messages,
},
)
with mlflow.start_span(name="_final_answer") as span:
# Initiate another span generation
span.set_inputs({"query": model_input[0]["query"]})
answer = output["choices"][0]["message"]["content"]
span.set_outputs({"generated_text": answer})
# Attributes computed at runtime can be set using the set_attributes() method.
span.set_attributes({
"model_name": params["model_name"],
"prompt_tokens": output["usage"]["prompt_tokens"],
"completion_tokens": output["usage"]["completion_tokens"],
"total_tokens": output["usage"]["total_tokens"]
})
return answer
Realizar inferência
Depois de instrumentar o código, o senhor pode executar a função como faria normalmente. A seguir, continuamos o exemplo com a função predict()
da seção anterior. Os traços são exibidos automaticamente quando o senhor executa o método de invocação, predict()
.
SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""
model = QAChain()
prediction = model.predict(
[
{"query": "What is in MLflow 5.0"},
],
SYSTEM_PROMPT,
{
# Using Databricks Foundation Model for easier testing, feel free to replace it.
"model_name": "databricks-dbrx-instruct",
"temperature": 0.1,
"max_tokens": 1000,
}
)
APIs fluentes
As APIs do Fluent no MLflow constroem automaticamente a hierarquia de rastreamento com base em onde e quando o código é executado. As seções a seguir descrevem a tarefa suportada usando o MLflow Tracing Fluent APIs.
Decore sua função
O senhor pode decorar sua função com o decorador @mlflow.trace
para criar um intervalo para o escopo da função decorada. O intervalo começa quando a função é chamada e termina quando ela retorna. O MLflow registra automaticamente a entrada e a saída da função, bem como todas as exceções geradas pela função. Por exemplo, a execução do código a seguir criará um intervalo com o nome "my_function", capturando os argumentos de entrada x e y, bem como a saída da função.
@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
return x + y
Usar o gerenciador de contexto de rastreamento
Se quiser criar uma extensão para um bloco arbitrário de código, não apenas uma função, o senhor pode usar mlflow.start_span()
como um gerenciador de contexto que envolve o bloco de código. O intervalo começa quando o contexto é inserido e termina quando o contexto é encerrado. A entrada e as saídas do span devem ser fornecidas manualmente por meio de métodos setter do objeto span que é gerado pelo gerenciador de contexto.
with mlflow.start_span("my_span") as span:
span.set_inputs({"x": x, "y": y})
result = x + y
span.set_outputs(result)
span.set_attribute("key", "value")
Envolver uma função externa
A função mlflow.trace
pode ser usada como um invólucro para rastrear uma função de sua escolha. Isso é útil quando o senhor deseja rastrear funções importadas de uma biblioteca externa. Ele gera o mesmo intervalo que o senhor obteria ao decorar essa função.
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)
APIs de cliente do MLflow
MlflowClient
expõe o APIs granular e seguro para threads para começar e terminar rastreamentos, gerenciar intervalos e definir campos de intervalo. Ele oferece controle total do ciclo de vida e da estrutura do rastreamento. Essas APIs são úteis quando as APIs do Fluent não são suficientes para seus requisitos, como aplicativos multithread e retornos de chamada.
Veja a seguir os passos para criar um rastreamento completo usando o cliente MLflow.
Crie uma instância do MLflowClient por
client = MlflowClient()
.Começar um rastreamento usando o método
client.start_trace()
. Isso inicia o contexto de rastreamento e inicia um período de raiz absoluto e retorna um objeto de período de raiz. Esse método deve ser executado antes dostart_span()
API.Defina seus atributos, entradas e saídas para o rastreamento em
client.start_trace()
.
Observação
Não há um equivalente ao método
start_trace()
nas APIs do Fluent. Isso ocorre porque o Fluent APIs inicializa automaticamente o contexto de rastreamento e determina se ele é a extensão raiz com base no estado de gerenciar.O começar() API retorna um intervalo. Obtenha a ID da solicitação, um identificador exclusivo do rastreamento, também chamado de
trace_id
, e a ID do intervalo retornado usandospan.request_id
espan.span_id
.Comece um intervalo filho usando
client.start_span(request_id, parent_id=span_id)
para definir seus atributos, entradas e saídas para o intervalo.Esse método requer
request_id
eparent_id
para associar o intervalo à posição correta na hierarquia de rastreamento. Ele retorna outro objeto span.
Encerre o intervalo filho chamando
client.end_span(request_id, span_id)
.Repita de 3 a 5 para todos os intervalos de crianças que o senhor deseja criar.
Depois que todos os períodos filhos forem encerrados, chame
client.end_trace(request_id)
para fechar todo o rastreamento e registrá-lo.
from mlflow.client import MlflowClient
mlflow_client = MlflowClient()
root_span = mlflow_client.start_trace(
name="simple-rag-agent",
inputs={
"query": "Demo",
"model_name": "DBRX",
"temperature": 0,
"max_tokens": 200
}
)
request_id = root_span.request_id
# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)
span_ss = mlflow_client.start_span(
"search",
# Specify request_id and parent_id to create the span at the right position in the trace
request_id=request_id,
parent_id=root_span.span_id,
inputs=similarity_search_input
)
retrieved = ["Test Result"]
# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)
root_span.end_trace(request_id, outputs={"output": retrieved})