AIエージェントを作成して記録する

プレビュー

この機能はパブリックプレビュー段階です。

この記事では、Mosaic AIエージェントフレームワークを使用してRAGアプリケーションなどのAIエージェントを作成し、ログに記録する方法を説明します。

チェーンとエージェントとは何ですか?

AIシステムには多くのコンポーネントが含まれることがよくあります。たとえば、AIは、ベクターインデックスからドキュメントを取得し、それらのドキュメントを使用してプロンプトテキストを補完し、基盤モデルを使用して応答を要約する場合があります。これらのコンポーネントをつなぐコードはステップとも呼ばれ、チェーンと呼ばれます

エージェントははるかに高度なAIシステムで、大規模言語モデルに依存して、入力に基づいてどのステップを実行するかを決定します。対照的に、チェーンは特定の結果を達成するためのハードコードされた一連のステップです。

Agent Frameworkでは、任意のライブラリやパッケージを使用してコードを作成できます。また、Agent Frameworkを使用すると、開発やテスト中にコードを繰り返し処理するのも簡単です。実際のコードを変更しなくても、追跡可能な方法でコードパラメータを変更できる設定ファイルを設定できます。

要件

DataBricksが管理するベクター検索インデックスを使用するエージェントの場合、ベクターインデックスで自動認証を使用するには、mlflowバージョン2.13.1以降が必要です。

RAGエージェントの入力スキーマ

チェーンでサポートされている入力形式は次のとおりです。

  • 推奨)OpenAIチャット補完スキーマを使用したクエリ―。messagesパラメータとしてオブジェクトの配列を持つ必要があります。この形式は、RAGアプリケーションに最適です。

    question = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-Augmented Generation?",
            },
            {
                "role": "assistant",
                "content": "RAG, or Retrieval Augmented Generation, is a generative AI design pattern that combines a large language model (LLM) with external knowledge retrieval. This approach allows for real-time data connection to generative AI applications, improving their accuracy and quality by providing context from your data to the LLM during inference. Databricks offers integrated tools that support various RAG scenarios, such as unstructured data, structured data, tools & function calling, and agents.",
            },
            {
                "role": "user",
                "content": "How to build RAG for unstructured data",
            },
        ]
    }
    
  • SplitChatMessagesRequest特に現在のクエリ―と履歴を個別に管理したい場合など、マルチターンチャットアプリケーションに推奨されます。

    {
    "query": "What is MLflow",
    "history": [
      {
      "role": "user",
      "content": "What is Retrieval-augmented Generation?"
      },
      {
      "role": "assistant",
      "content": "RAG is"
      }
      ]
    }
    

LangChainの場合、Databricksでは、チェーンをLangChain Expression Languageで記述することを推奨しています。チェーン定義コードでは、使用している入力形式に応じてitemgetterを使用してメッセージを取得したり、queryまたはhistoryオブジェクトを使用したりできます。

RAGエージェントの出力スキーマ

コードは、サポートされている次の出力形式のいずれかに準拠する必要があります。

  • (推奨)ChatCompletionResponse。この形式は、OpenAI応答形式の相互運用性を持つ顧客に推奨されます。

  • StringResponse。この形式は、解釈が簡単で最も簡単です。

LangChainの場合、最後のチェーンステップとしてStrOutputParser()を使用します。出力は1つの文字列値を返す必要があります。

  chain = (
      {
          "user_query": itemgetter("messages")
          | RunnableLambda(extract_user_query_string),
          "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
      }
      | RunnableLambda(fake_model)
      | StrOutputParser()
  )

PyFuncを使用している場合、Databricksでは、タイプヒントを使用して、mlflow.models.rag_signaturesで定義されたクラスのサブクラスである入力および出力データクラスでpredict()関数に注釈を付けることをお勧めします。

形式が確実に守られるように、predict()内のデータクラスから出力オブジェクトを構築できます。返されたオブジェクトは、シリアル化できるように辞書表現に変換する必要があります。

パラメータを使用して品質の反復を制御する

エージェントフレームワークでは、パラメータを使用してエージェントの実行方法を制御できます。これにより、コードを変更せずにエージェントの特性を変えて素早く反復処理を行うことができます。パラメータは、Python辞書または.yamlファイルで定義するキーと値のペアです。

コードを構成するには、キーと値のパラメータのセットであるModelConfigを作成します。ModelConfigはPython辞書または.yamlファイルのいずれかです。たとえば、開発中に辞書を使用し、それを本番運用デプロイメントとCI/CD用の.yamlファイルに変換することができます。ModelConfigの詳細については、「MLflowのドキュメント」を参照してください。

以下にModelConfigの例を示します。

llm_parameters:
  max_tokens: 500
  temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
  question that indicates your prompt template came from a YAML file. Your response
  must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question

コードからコンフィギュレーションを呼び出すには、以下のいずれかを使用します。

# Example for loading from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Example of using a dictionary
config_dict = {
    "prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}

model_config = mlflow.models.ModelConfig(development_config=config_dict)

# Use model_config.get() to retrieve a parameter value
value = model_config.get('sample_param')

エージェントをログに記録する

エージェントのログ記録は、開発プロセスの基礎です。ログ記録により、エージェントのコードと構成の「ある時点」がキャプチャされるため、構成の品質を評価できます。エージェントを開発するとき、Databricksはシリアル化ベースのログ記録ではなくコードベースのログ記録を使用することを推奨しています。各タイプのログ記録の長所と短所の情報については、「コードベースのログ記録とシリアル化ベースのログ記録」を参照してください。

このセクションでは、コードベースのログ記録の使い方を説明します。シリアル化ベースのログ記録の使用方法の詳細については、「シリアル化ベースのログ記録ワークフロー」を参照してください。

コードベースのログ記録ワークフロー

コードベースのログ記録では、エージェントまたはチェーンを記録するコードは、チェーンコードとは別のノートブックに保存する必要があります。このノートブックは、ドライバーノートブックと呼ばれます。ノートブックの例については、「ノートブックの例」を参照してください。

LangChainを使用したコードベースのログ記録ワークフロー

  1. あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルはchain.pyという名前です。ノートブックまたはファイルには、ここではlc_chainと呼ばれるLangChainチェーンが含まれている必要があります。

  2. ノートブックまたはファイルにmlflow.models.set_model(lc_chain)を含めます。

  3. ドライバーノートブックとして機能する新しいノートブックを作成します(この例ではdriver.pyと呼ばれます)。

  4. ドライバーノートブックにmlflow.lang_chain.log_model(lc_model=”/path/to/chain.py”)という呼び出しを含めます。この呼び出しはchain.pyを実行し、結果をMLflowモデルに記録します。

  5. モデルをデプロイします。

  6. サービング環境がロードされるとchain.pyが実行されます。

  7. サービングリクエストが来るとlc_chain.invoke(...)が呼び出されます。

PyFuncを使用したコードベースのログ記録ワークフロー

  1. あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルはchain.pyという名前です。ノートブックまたはファイルには、ここではPyFuncClassと呼ばれるPyFuncクラスが含まれている必要があります。

  2. ノートブックまたはファイルにmlflow.models.set_model(PyFuncClass)を含めます。

  3. ドライバーノートブックとして機能する新しいノートブックを作成します(この例ではdriver.pyと呼ばれます)。

  4. ドライバーノートブックにmlflow.pyfunc.log_model(python_model=”/path/to/chain.py”)という呼び出しを含めます。この呼び出しはchain.pyを実行し、結果をMLflowモデルに記録します。

  5. モデルをデプロイします。

  6. サービング環境がロードされるとchain.pyが実行されます。

  7. サービングリクエストが来るとPyFuncClass.predict(...)が呼び出されます。

ログチェーンのサンプルコード

import mlflow

code_path = "/Workspace/Users/first.last/chain.py"
config_path = "/Workspace/Users/first.last/config.yml"

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is Retrieval-augmented Generation?",
        }
    ]
}

# example using LangChain
with mlflow.start_run():
  logged_chain_info = mlflow.langchain.log_model(
    lc_model=code_path,
    model_config=config_path, # If you specify this parameter, this is the configuration that is used for training the model. The development_config is overwritten.
    artifact_path="chain", # This string is used as the path inside the MLflow model where artifacts are stored
    input_example=input_example, # Must be a valid input to your chain
    example_no_conversion=True, # Required
  )

# or use a PyFunc model
# with mlflow.start_run():
#   logged_chain_info = mlflow.pyfunc.log_model(
#     python_model=chain_notebook_path,
#     artifact_path="chain",
#     input_example=input_example,
#     example_no_conversion=True,
#   )

print(f"MLflow Run: {logged_chain_info.run_id}")
print(f"Model URI: {logged_chain_info.model_uri}")

モデルが正しくログに記録されたことを確認するには、チェーンをロードしてinvokeを呼び出します。

# Using LangChain
model = mlflow.langchain.load_model(logged_chain_info.model_uri)
model.invoke(example)

# Using PyFunc
model = mlflow.pyfunc.load_model(logged_chain_info.model_uri)
model.invoke(example)

チェーンをUnity Catalogに登録する

チェーンをデプロイする前に、チェーンをUnity Catalogに登録する必要があります。チェーンを登録すると、Unity Catalogにモデルとしてパッケージ化され、チェーン内のリソースの認可にUnity Catalogのパーミッションを使用できます。

import mlflow

mlflow.set_registry_uri("databricks-uc")

catalog_name = "test_catalog"
schema_name = "schema"
model_name = "chain_name"

model_name = catalog_name + "." + schema_name + "." + model_name
uc_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=model_name)

ノートブックの例

これらのノートブックは、Databricksでチェーンアプリケーションを作成する方法を示すために、単純な「Hello, world」チェーンを作成します。最初の例では、単純なチェーンを作成します。2番目のノートブック例では、開発中にコードの変更を最小限に抑えるためにパラメータを使用する方法を示します。

シンプルなチェーンノートブック

ノートブックを新しいタブで開く

シンプルチェーンドライバーノート

ノートブックを新しいタブで開く

パラメータ化されたチェーンノートブック

ノートブックを新しいタブで開く

パラメータ化されたチェーンドライバーノートブック

ノートブックを新しいタブで開く

コードベースのログ記録とシリアライズベースのログ記録

チェーンを作成してログに記録するには、コードベースのMLflowログ記録またはシリアル化ベースのMLflowログ記録を使用できます。Databricksでは、コードベースのログ記録を使用することをお勧めします。

コードベースのMLflowログ記録では、チェーンのコードはPythonファイルとしてキャプチャされます。Python環境はパッケージのリストとしてキャプチャされます。チェーンがデプロイされると、Python環境が復元され、チェーンのコードが実行されてチェーンがメモリにロードされます。

シリアル化ベースのMLflowログ記録では、チェーンのコードとPython環境の現在の状態が、多くの場合picklejoblibなどのライブラリを使用してディスクにシリアル化されます。チェーンがデプロイされると、Python環境が復元され、シリアル化されたオブジェクトがメモリに読み込まれ、エンドポイントが呼び出されたときに呼び出せるようになります。

表は、各方法の長所と短所を示しています。

メソッド

利点

デメリット

コードベースのMLflowログ記録

  • 多くの一般的な生成AIライブラリでサポートされていないシリアル化の固有の制限を克服します。

  • 後で参照できるように、元のコードのコピーを保存します。

  • コードをシリアル化できる単一のオブジェクトに再構築する必要はありません。

log_model(...) チェーンのコードとは別のノートブック(ドライバーノートブックと呼ばれます)から呼び出す必要があります。

シリアル化ベースのMLflowログ記録

log_model(...) モデルが定義されている同じノートブックから呼び出すことができます。

  • 元のコードは使用できません。

  • チェーンで使用されるすべてのライブラリとオブジェクトは、シリアル化をサポートしている必要があります。

シリアル化ベースのログ記録ワークフロー

Databricks では、シリアル化ベースのロギングではなく、コードベースのロギングを使用することを推奨しています。 コードベースのログ記録の使用方法の詳細については、「 コードベースのログ記録ワークフロー」を参照してください。

このセクションでは、シリアル化ベースのログ記録の使用方法について説明します。

LangChainを使用したシリアル化ベースのログ記録ワークフロー

  1. あなたのコードでノートブックまたはPythonファイルを作成します。ノートブックまたはファイルには、ここではlc_chainと呼ばれるLangChainチェーンが含まれている必要があります。

  2. ノートブックまたはファイルにmlflow.lang_chain.log_model(lc_model=lc_chain)を含めます。

  3. PyFuncClass()のシリアル化されたコピーがMLflowモデルに記録されます。

  4. モデルをデプロイします。

  5. サービス環境がロードされると、PyFuncClassはデシリアル化されます。

  6. サービングリクエストが来るとlc_chain.invoke(...)が呼び出されます。

PyFuncによるシリアライズベースのログ記録ワークフロー

  1. あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルはnotebook.pyという名前です。ノートブックまたはファイルには、ここではPyFuncClassと呼ばれるPyFuncクラスが含まれている必要があります。

  2. notebook.pymlflow.pyfunc.log_model(python_model=PyFuncClass())を含めます。

  3. PyFuncClass()のシリアル化されたコピーがMLflowモデルに記録されます。

  4. モデルをデプロイします。

  5. サービス環境がロードされると、PyFuncClassはデシリアル化されます。

  6. サービングリクエストが来るとPyFuncClass.predict(...)が呼び出されます。