メインコンテンツまでスキップ

AIエージェントメモリ

備考

プレビュー

この機能は パブリック プレビュー段階です。

メモリにより、AI エージェントは会話の前の部分や以前の会話からの情報を記憶できます。これにより、エージェントはコンテキストに応じた応答を提供し、時間の経過とともにパーソナライズされたエクスペリエンスを構築できるようになります。フルマネージド Postgres OLTP データベースであるDatabricks Lakebase を使用して、会話の状態と履歴を管理します。

要件

短期記憶と長期記憶

短期記憶は単一の会話セッション内のコンテキストをキャプチャしますが、長期記憶は複数の会話にわたる重要な情報を抽出して保存します。どちらか一方または両方のタイプのメモリを使用してエージェントを構築できます。

短期記憶と長期記憶を持つエージェント

短期記憶

長期記憶

スレッドIDとチェックポイントを使用して、単一の会話セッションでコンテキストをキャプチャします。

セッション中のフォローアップの質問の文脈を維持する

タイムトラベルを使用して会話フローをデバッグおよびテストする

複数のセッションにわたって重要な知識を自動的に抽出して保存

過去の好みに基づいてやり取りをパーソナライズする

時間の経過とともに応答を改善するユーザーに関する知識ベースを構築する

ノートブックの例

短期記憶を持つエージェント

Open notebook in new tab

長期記憶を持つエージェント

Open notebook in new tab

デプロイされたエージェントをクエリする

エージェントをモデルサービングエンドポイントにデプロイした後、クエリの手順については、デプロイされた Mosaic AI エージェントのクエリを参照してください。

スレッド ID を渡すには、 extra_bodyパラメータを使用します。次の例は、スレッド ID をResponsesAgentエンドポイントに渡す方法を示しています。

Python
   response1 = client.responses.create(
model=endpoint,
input=[{"role": "user", "content": "What are stateful agents?"}],
extra_body={
"custom_inputs": {"thread_id": thread_id}
}
)

Playground や Review アプリのようにChatContext を自動的に渡すクライアントを使用している場合は、短期/長期メモリの使用例で会話 ID とユーザー ID が自動的に渡されます。

短期記憶のタイムトラベル

短期記憶を持つエージェントの場合は、 LangGraph タイムトラベルを使用してチェックポイントから実行を再開します。会話を再生するか、会話を変更して別のパスを検討することができます。チェックポイントから再開するたびに、LangGraph は会話履歴に新しいフォークを作成し、元の会話を保持しながら実験を可能にします。

  1. エージェント コードで、 LangGraphResponsesAgentクラス内のチェックポイント履歴を取得し、チェックポイントの状態を更新する関数を作成します。

    Python
    from typing import List, Dict
    def get_checkpoint_history(self, thread_id: str, limit: int = 10) -> List[Dict[str, Any]]:
    """Retrieve checkpoint history for a thread.

    Args:
    thread_id: The thread identifier
    limit: Maximum number of checkpoints to return

    Returns:
    List of checkpoint information including checkpoint_id, timestamp, and next nodes
    """
    config = {"configurable": {"thread_id": thread_id}}

    with CheckpointSaver(instance_name=LAKEBASE_INSTANCE_NAME) as checkpointer:
    graph = self._create_graph(checkpointer)

    history = []
    for state in graph.get_state_history(config):
    if len(history) >= limit:
    break

    history.append({
    "checkpoint_id": state.config["configurable"]["checkpoint_id"],
    "thread_id": thread_id,
    "timestamp": state.created_at,
    "next_nodes": state.next,
    "message_count": len(state.values.get("messages", [])),
    # Include last message summary for context
    "last_message": self._get_last_message_summary(state.values.get("messages", []))
    })

    return history

    def _get_last_message_summary(self, messages: List[Any]) -> Optional[str]:
    """Get a snippet of the last message for checkpoint identification"""
    return getattr(messages[-1], "content", "")[:100] if messages else None

    def update_checkpoint_state(self, thread_id: str, checkpoint_id: str,
    new_messages: Optional[List[Dict]] = None) -> Dict[str, Any]:
    """Update state at a specific checkpoint (used for modifying conversation history).

    Args:
    thread_id: The thread identifier
    checkpoint_id: The checkpoint to update
    new_messages: Optional new messages to set at this checkpoint

    Returns:
    New checkpoint configuration including the new checkpoint_id
    """
    config = {
    "configurable": {
    "thread_id": thread_id,
    "checkpoint_id": checkpoint_id
    }
    }

    with CheckpointSaver(instance_name=LAKEBASE_INSTANCE_NAME) as checkpointer:
    graph = self._create_graph(checkpointer)

    # Prepare the values to update
    values = {}
    if new_messages:
    cc_msgs = self.prep_msgs_for_cc_llm(new_messages)
    values["messages"] = cc_msgs

    # Update the state (creates a new checkpoint)
    new_config = graph.update_state(config, values=values)

    return {
    "thread_id": thread_id,
    "checkpoint_id": new_config["configurable"]["checkpoint_id"],
    "parent_checkpoint_id": checkpoint_id
    }
  2. チェックポイントの受け渡しをサポートするために、 predictpredict_stream関数を更新します。

Python
def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
"""Non-streaming prediction"""
# The same thread_id is used by BOTH predict() and predict_stream()
ci = dict(request.custom_inputs or {})
if "thread_id" not in ci:
ci["thread_id"] = str(uuid.uuid4())
request.custom_inputs = ci

outputs = [
event.item
for event in self.predict_stream(request)
if event.type == "response.output_item.done"
]

# Include thread_id and checkpoint_id in custom outputs
custom_outputs = {
"thread_id": ci["thread_id"]
}
if "checkpoint_id" in ci:
custom_outputs["parent_checkpoint_id"] = ci["checkpoint_id"]

try:
history = self.get_checkpoint_history(ci["thread_id"], limit=1)
if history:
custom_outputs["checkpoint_id"] = history[0]["checkpoint_id"]
except Exception as e:
logger.warning(f"Could not retrieve new checkpoint_id: {e}")

return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs)

次に、チェックポイント分岐をテストします。

  1. 会話スレッドを開始し、いくつかのメッセージを追加します。

    Python
    from agent import AGENT
    # Initial conversation - starts a new thread
    response1 = AGENT.predict({
    "input": [{"role": "user", "content": "I'm planning for an upcoming trip!"}],
    })
    print(response1.model_dump(exclude_none=True))
    thread_id = response1.custom_outputs["thread_id"]

    # Within the same thread, ask a follow-up question - short-term memory will remember previous messages in the same thread/conversation session
    response2 = AGENT.predict({
    "input": [{"role": "user", "content": "I'm headed to SF!"}],
    "custom_inputs": {"thread_id": thread_id}
    })
    print(response2.model_dump(exclude_none=True))

    # Within the same thread, ask a follow-up question - short-term memory will remember previous messages in the same thread/conversation session
    response3 = AGENT.predict({
    "input": [{"role": "user", "content": "Where did I say I'm going?"}],
    "custom_inputs": {"thread_id": thread_id}
    })
    print(response3.model_dump(exclude_none=True))

  2. チェックポイント履歴を取得し、別のメッセージで会話をフォークします。

    Python
    # Get checkpoint history to find branching point
    history = AGENT.get_checkpoint_history(thread_id, 20)
    # Retrieve checkpoint at index - indices count backward from most recent checkpoint
    index = max(1, len(history) - 4)
    branch_checkpoint = history[index]["checkpoint_id"]

    # Branch from node with next_node = `('__start__',)` to re-input message to agent at certain part of conversation
    # I want to update the information of which city I am going to
    # Within the same thread, branch from a checkpoint and override it with different context to continue the conversation in a new fork
    response4 = AGENT.predict({
    "input": [{"role": "user", "content": "I'm headed to New York!"}],
    "custom_inputs": {
    "thread_id": thread_id,
    "checkpoint_id": branch_checkpoint # Branch from this checkpoint!
    }
    })
    print(response4.model_dump(exclude_none=True))

    # Thread ID stays the same even though it branched from a checkpoint:
    branched_thread_id = response4.custom_outputs["thread_id"]
    print(f"original thread id was {thread_id}")
    print(f"new thread id after branching is the same as original: {branched_thread_id}")

    # Continue the conversation in the same thread and it will pick up from the information you tell it in your branch
    response5 = AGENT.predict({
    "input": [{"role": "user", "content": "Where am I going?"}],
    "custom_inputs": {
    "thread_id": thread_id,
    }
    })
    print(response5.model_dump(exclude_none=True))

次のステップ