AIエージェントを作成して記録する
プレビュー
この機能はパブリックプレビュー段階です。
この記事では、Mosaic AIエージェントフレームワークを使用してRAGアプリケーションなどのAIエージェントを作成し、ログに記録する方法を説明します。
チェーンとエージェントとは何ですか?
AIシステムには多くのコンポーネントが含まれることがよくあります。たとえば、AIは、ベクターインデックスからドキュメントを取得し、それらのドキュメントを使用してプロンプトテキストを補完し、基盤モデルを使用して応答を要約する場合があります。これらのコンポーネントをつなぐコードはステップとも呼ばれ、チェーンと呼ばれます。
エージェントははるかに高度なAIシステムで、大規模言語モデルに依存して、入力に基づいてどのステップを実行するかを決定します。対照的に、チェーンは特定の結果を達成するためのハードコードされた一連のステップです。
Agent Frameworkでは、任意のライブラリやパッケージを使用してコードを作成できます。また、Agent Frameworkを使用すると、開発やテスト中にコードを繰り返し処理するのも簡単です。実際のコードを変更しなくても、追跡可能な方法でコードパラメータを変更できる設定ファイルを設定できます。
RAGエージェントの入力スキーマ
チェーンでサポートされている入力形式は次のとおりです。
(推奨)OpenAIチャット補完スキーマを使用したクエリ―。
messages
パラメータとしてオブジェクトの配列を持つ必要があります。この形式は、RAGアプリケーションに最適です。question = { "messages": [ { "role": "user", "content": "What is Retrieval-Augmented Generation?", }, { "role": "assistant", "content": "RAG, or Retrieval Augmented Generation, is a generative AI design pattern that combines a large language model (LLM) with external knowledge retrieval. This approach allows for real-time data connection to generative AI applications, improving their accuracy and quality by providing context from your data to the LLM during inference. Databricks offers integrated tools that support various RAG scenarios, such as unstructured data, structured data, tools & function calling, and agents.", }, { "role": "user", "content": "How to build RAG for unstructured data", }, ] }
SplitChatMessagesRequest
特に現在のクエリ―と履歴を個別に管理したい場合など、マルチターンチャットアプリケーションに推奨されます。{ "query": "What is MLflow", "history": [ { "role": "user", "content": "What is Retrieval-augmented Generation?" }, { "role": "assistant", "content": "RAG is" } ] }
LangChainの場合、Databricksでは、チェーンをLangChain Expression Languageで記述することを推奨しています。チェーン定義コードでは、使用している入力形式に応じてitemgetter
を使用してメッセージを取得したり、query
またはhistory
オブジェクトを使用したりできます。
RAGエージェントの出力スキーマ
コードは、サポートされている次の出力形式のいずれかに準拠する必要があります。
(推奨)ChatCompletionResponse。この形式は、OpenAI応答形式の相互運用性を持つ顧客に推奨されます。
StringResponse。この形式は、解釈が簡単で最も簡単です。
LangChainの場合、最後のチェーンステップとしてStrOutputParser()
を使用します。出力は1つの文字列値を返す必要があります。
chain = (
{
"user_query": itemgetter("messages")
| RunnableLambda(extract_user_query_string),
"chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
}
| RunnableLambda(fake_model)
| StrOutputParser()
)
PyFuncを使用している場合、Databricksでは、タイプヒントを使用して、mlflow.models.rag_signatures
で定義されたクラスのサブクラスである入力および出力データクラスでpredict()
関数に注釈を付けることをお勧めします。
形式が確実に守られるように、predict()
内のデータクラスから出力オブジェクトを構築できます。返されたオブジェクトは、シリアル化できるように辞書表現に変換する必要があります。
パラメータを使用して品質の反復を制御する
エージェントフレームワークでは、パラメータを使用してエージェントの実行方法を制御できます。これにより、コードを変更せずにエージェントの特性を変えて素早く反復処理を行うことができます。パラメータは、Python辞書または.yaml
ファイルで定義するキーと値のペアです。
コードを構成するには、キーと値のパラメータのセットであるModelConfig
を作成します。ModelConfig
はPython辞書または.yaml
ファイルのいずれかです。たとえば、開発中に辞書を使用し、それを本番運用デプロイメントとCI/CD用の.yaml
ファイルに変換することができます。ModelConfig
の詳細については、「MLflowのドキュメント」を参照してください。
以下にModelConfig
の例を示します。
llm_parameters:
max_tokens: 500
temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
question that indicates your prompt template came from a YAML file. Your response
must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question
コードからコンフィギュレーションを呼び出すには、以下のいずれかを使用します。
# Example for loading from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)
# Example of using a dictionary
config_dict = {
"prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
"prompt_template_input_vars": ["question"],
"model_serving_endpoint": "databricks-dbrx-instruct",
"llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}
model_config = mlflow.models.ModelConfig(development_config=config_dict)
# Use model_config.get() to retrieve a parameter value
value = model_config.get('sample_param')
エージェントをログに記録する
エージェントのログ記録は、開発プロセスの基礎です。ログ記録により、エージェントのコードと構成の「ある時点」がキャプチャされるため、構成の品質を評価できます。エージェントを開発するとき、Databricksはシリアル化ベースのログ記録ではなくコードベースのログ記録を使用することを推奨しています。各タイプのログ記録の長所と短所の情報については、「コードベースのログ記録とシリアル化ベースのログ記録」を参照してください。
このセクションでは、コードベースのログ記録の使い方を説明します。シリアル化ベースのログ記録の使用方法の詳細については、「シリアル化ベースのログ記録ワークフロー」を参照してください。
コードベースのログ記録ワークフロー
コードベースのログ記録では、エージェントまたはチェーンを記録するコードは、チェーンコードとは別のノートブックに保存する必要があります。このノートブックは、ドライバーノートブックと呼ばれます。ノートブックの例については、「ノートブックの例」を参照してください。
LangChainを使用したコードベースのログ記録ワークフロー
あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルは
chain.py
という名前です。ノートブックまたはファイルには、ここではlc_chain
と呼ばれるLangChainチェーンが含まれている必要があります。ノートブックまたはファイルに
mlflow.models.set_model(lc_chain)
を含めます。ドライバーノートブックとして機能する新しいノートブックを作成します(この例では
driver.py
と呼ばれます)。ドライバーノートブックに
mlflow.lang_chain.log_model(lc_model=”/path/to/chain.py”)
という呼び出しを含めます。この呼び出しはchain.py
を実行し、結果をMLflowモデルに記録します。モデルをデプロイします。
サービング環境がロードされると
chain.py
が実行されます。サービングリクエストが来ると
lc_chain.invoke(...)
が呼び出されます。
PyFuncを使用したコードベースのログ記録ワークフロー
あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルは
chain.py
という名前です。ノートブックまたはファイルには、ここではPyFuncClass
と呼ばれるPyFuncクラスが含まれている必要があります。ノートブックまたはファイルに
mlflow.models.set_model(PyFuncClass)
を含めます。ドライバーノートブックとして機能する新しいノートブックを作成します(この例では
driver.py
と呼ばれます)。ドライバーノートブックに
mlflow.pyfunc.log_model(python_model=”/path/to/chain.py”)
という呼び出しを含めます。この呼び出しはchain.py
を実行し、結果をMLflowモデルに記録します。モデルをデプロイします。
サービング環境がロードされると
chain.py
が実行されます。サービングリクエストが来ると
PyFuncClass.predict(...)
が呼び出されます。
ログチェーンのサンプルコード
import mlflow
code_path = "/Workspace/Users/first.last/chain.py"
config_path = "/Workspace/Users/first.last/config.yml"
input_example = {
"messages": [
{
"role": "user",
"content": "What is Retrieval-augmented Generation?",
}
]
}
# example using LangChain
with mlflow.start_run():
logged_chain_info = mlflow.langchain.log_model(
lc_model=code_path,
model_config=config_path, # If you specify this parameter, this is the configuration that is used for training the model. The development_config is overwritten.
artifact_path="chain", # This string is used as the path inside the MLflow model where artifacts are stored
input_example=input_example, # Must be a valid input to your chain
example_no_conversion=True, # Required
)
# or use a PyFunc model
# with mlflow.start_run():
# logged_chain_info = mlflow.pyfunc.log_model(
# python_model=chain_notebook_path,
# artifact_path="chain",
# input_example=input_example,
# example_no_conversion=True,
# )
print(f"MLflow Run: {logged_chain_info.run_id}")
print(f"Model URI: {logged_chain_info.model_uri}")
モデルが正しくログに記録されたことを確認するには、チェーンをロードしてinvoke
を呼び出します。
# Using LangChain
model = mlflow.langchain.load_model(logged_chain_info.model_uri)
model.invoke(example)
# Using PyFunc
model = mlflow.pyfunc.load_model(logged_chain_info.model_uri)
model.invoke(example)
チェーンをUnity Catalogに登録する
チェーンをデプロイする前に、チェーンをUnity Catalogに登録する必要があります。チェーンを登録すると、Unity Catalogにモデルとしてパッケージ化され、チェーン内のリソースの認可にUnity Catalogのパーミッションを使用できます。
import mlflow
mlflow.set_registry_uri("databricks-uc")
catalog_name = "test_catalog"
schema_name = "schema"
model_name = "chain_name"
model_name = catalog_name + "." + schema_name + "." + model_name
uc_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=model_name)
ノートブックの例
これらのノートブックは、Databricksでチェーンアプリケーションを作成する方法を示すために、単純な「Hello, world」チェーンを作成します。最初の例では、単純なチェーンを作成します。2番目のノートブック例では、開発中にコードの変更を最小限に抑えるためにパラメータを使用する方法を示します。
コードベースのログ記録とシリアライズベースのログ記録
チェーンを作成してログに記録するには、コードベースのMLflowログ記録またはシリアル化ベースのMLflowログ記録を使用できます。Databricksでは、コードベースのログ記録を使用することをお勧めします。
コードベースのMLflowログ記録では、チェーンのコードはPythonファイルとしてキャプチャされます。Python環境はパッケージのリストとしてキャプチャされます。チェーンがデプロイされると、Python環境が復元され、チェーンのコードが実行されてチェーンがメモリにロードされます。
シリアル化ベースのMLflowログ記録では、チェーンのコードとPython環境の現在の状態が、多くの場合pickle
やjoblib
などのライブラリを使用してディスクにシリアル化されます。チェーンがデプロイされると、Python環境が復元され、シリアル化されたオブジェクトがメモリに読み込まれ、エンドポイントが呼び出されたときに呼び出せるようになります。
表は、各方法の長所と短所を示しています。
メソッド |
利点 |
デメリット |
---|---|---|
コードベースのMLflowログ記録 |
|
|
シリアル化ベースのMLflowログ記録 |
|
|
シリアル化ベースのログ記録ワークフロー
Databricks では、シリアル化ベースのロギングではなく、コードベースのロギングを使用することを推奨しています。 コードベースのログ記録の使用方法の詳細については、「 コードベースのログ記録ワークフロー」を参照してください。
このセクションでは、シリアル化ベースのログ記録の使用方法について説明します。
LangChainを使用したシリアル化ベースのログ記録ワークフロー
あなたのコードでノートブックまたはPythonファイルを作成します。ノートブックまたはファイルには、ここでは
lc_chain
と呼ばれるLangChainチェーンが含まれている必要があります。ノートブックまたはファイルに
mlflow.lang_chain.log_model(lc_model=lc_chain)
を含めます。PyFuncClass()
のシリアル化されたコピーがMLflowモデルに記録されます。モデルをデプロイします。
サービス環境がロードされると、
PyFuncClass
はデシリアル化されます。サービングリクエストが来ると
lc_chain.invoke(...)
が呼び出されます。
PyFuncによるシリアライズベースのログ記録ワークフロー
あなたのコードでノートブックまたはPythonファイルを作成します。この例では、ノートブックまたはファイルは
notebook.py
という名前です。ノートブックまたはファイルには、ここではPyFuncClass
と呼ばれるPyFuncクラスが含まれている必要があります。notebook.py
にmlflow.pyfunc.log_model(python_model=PyFuncClass())
を含めます。PyFuncClass()
のシリアル化されたコピーがMLflowモデルに記録されます。モデルをデプロイします。
サービス環境がロードされると、
PyFuncClass
はデシリアル化されます。サービングリクエストが来ると
PyFuncClass.predict(...)
が呼び出されます。