エージェントにトレースを追加する

重要

この機能はパブリックプレビュー段階です。

この記事では、MLflow Tracing APIs で利用できる Fluent および MLflowClient を使用してエージェントにトレースを追加する方法を説明します。

注:

MLflow Tracing の詳細な API リファレンスとコード例については、 MLflow ドキュメントを参照してください。

要件

  • MLflow 2.13.1

Fluent APIsを使用してエージェントにトレースを追加する

以下は、 Fluent APIs mlflow.tracemlflow.start_span を使用して quickstart-agent にトレースを追加する簡単な例です。


import mlflow
from mlflow.deployments import get_deploy_client


class QAChain(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.client = get_deploy_client("databricks")

    @mlflow.trace(name="quickstart-agent")
    def predict(self, model_input, system_prompt, params):
        messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content":  model_input[0]["query"]
                }
          ]

        traced_predict = mlflow.trace(self.client.predict)
        output = traced_predict(
            endpoint=params["model_name"],
            inputs={
                "temperature": params["temperature"],
                "max_tokens": params["max_tokens"],
                "messages": messages,
            },
        )

        with mlflow.start_span(name="_final_answer") as span:
          # Initiate another span generation
            span.set_inputs({"query": model_input[0]["query"]})

            answer = output["choices"][0]["message"]["content"]

            span.set_outputs({"generated_text": answer})
            # Attributes computed at runtime can be set using the set_attributes() method.
            span.set_attributes({
              "model_name": params["model_name"],
                        "prompt_tokens": output["usage"]["prompt_tokens"],
                        "completion_tokens": output["usage"]["completion_tokens"],
                        "total_tokens": output["usage"]["total_tokens"]
                    })
              return answer

推論の実行

コードをインストルメント化したら、通常どおり関数を実行できます。 次に、前のセクションの predict() 関数を使用した例の続きを示します。 呼び出しメソッドpredict()を実行すると、トレースが自動的に表示されます。


SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""

model = QAChain()

prediction = model.predict(
  [
      {"query": "What is in MLflow 5.0"},
  ],
  SYSTEM_PROMPT,
  {
    # Using Databricks Foundation Model for easier testing, feel free to replace it.
    "model_name": "databricks-dbrx-instruct",
    "temperature": 0.1,
    "max_tokens": 1000,
  }
)

Fluent APIs

の FluentAPIs 、コードが実行される場所と時間に基づいてトレース階層を自動的に構築します。MLflow次のセクションでは、 MLflow Tracing Fluent APIsを使用してサポートされているタスクについて説明します。

関数を修飾する

@mlflow.trace デコレーターで関数を修飾して、装飾された関数のスコープのスパンを作成できます。スパンは、関数が呼び出されたときに開始され、関数が戻ったときに終了します。 MLflow は、関数の入力と出力、および関数から発生した例外を自動的に記録します。 たとえば、次のコードを実行すると、"my_function" という名前のスパンが作成され、入力引数 x と y 、および関数の出力がキャプチャされます。

@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
    return x + y

トレースコンテキストマネージャーの使用

関数だけでなく、任意のコードブロックのスパンを作成する場合は、コードブロックをラップするコンテキストマネージャーとして mlflow.start_span() を使用できます。 スパンは、コンテキストが入ったときに開始し、コンテキストが終了したときに終了します。 span の入力と出力は、コンテキストマネージャから生成された span オブジェクトの setter メソッドを介して手動で提供する必要があります。

with mlflow.start_span("my_span") as span:
    span.set_inputs({"x": x, "y": y})
    result = x + y
    span.set_outputs(result)
    span.set_attribute("key", "value")

外部関数をラップする

mlflow.trace 関数は、選択した関数をトレースするためのラッパーとして使用できます。これは、外部ライブラリからインポートされた関数をトレースする場合に便利です。 その関数をデコレートすることで得られるのと同じスパンを生成します。


from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)

MLflowクライアントAPIs

MlflowClient トレースを開始および終了し、スパンを管理し、スパン フィールドを設定するためのきめ細かいスレッド セーフAPIsを公開します。 これにより、トレースのライフサイクルと構造を完全に制御できます。 これらのAPIs 、マルチスレッド アプリケーションやコールバックなどの要件に対して Fluent APIs不十分な場合に役立ちます。

以下は、 MLflowクライアントを使用して完全なトレースを作成するための手順です。

  1. client = MlflowClient()を使用して MLflowClient のインスタンスを作成します。

  2. client.start_trace() メソッドを使用してトレースを開始します。これにより、トレース コンテキストが開始され、絶対ルート スパンが開始され、ルート スパン オブジェクトが返されます。 このメソッドは、 start_span() API の前に実行する必要があります。

    1. トレースの属性、入力、出力を client.start_trace()で設定します。

    注:

    Fluent APIsには start_trace() メソッドに相当するものはありません。 これは、Fluent APIsトレース コンテキストを自動的に初期化し、管理された状態に基づいてルート スパンであるかどうかを判断するためです。

  3. start_trace() API はスパンを返します。 span.request_idspan.span_idを使用して、要求 ID、トレースの一意の識別子 (trace_idとも呼ばれる)、および返されたスパンの ID を取得します。

  4. client.start_span(request_id, parent_id=span_id) を使用して子スパンを開始し、スパンの属性、入力、および出力を設定します。

    1. この方法では、スパンをトレース階層内の正しい位置に関連付けるために、 request_idparent_id が必要です。 別の span オブジェクトを返します。

  5. 子スパンを終了するには、 client.end_span(request_id, span_id)を呼び出します。

  6. 作成する子スパンに対して 3 から 5 を繰り返します。

  7. すべての子スパンが終了したら、 client.end_trace(request_id) を呼び出してトレース全体を閉じ、記録します。

from mlflow.client import MlflowClient

mlflow_client = MlflowClient()

root_span = mlflow_client.start_trace(
  name="simple-rag-agent",
  inputs={
          "query": "Demo",
          "model_name": "DBRX",
          "temperature": 0,
          "max_tokens": 200
         }
  )

request_id = root_span.request_id

# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)

span_ss = mlflow_client.start_span(
      "search",
      # Specify request_id and parent_id to create the span at the right position in the trace
        request_id=request_id,
        parent_id=root_span.span_id,
        inputs=similarity_search_input
  )
retrieved = ["Test Result"]

# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)

root_span.end_trace(request_id, outputs={"output": retrieved})