ステートフルAIエージェント
プレビュー
この機能は パブリック プレビュー段階です。
ステートフル AI エージェントは、スレッド ID を使用してスレッドを追跡し、インタラクション全体のコンテキストを維持します。チェックポイントを使用すると、エージェントを特定の状態に保存でき、タイムトラベルを使用すると、それらの状態からの会話を再生できます。これにより、非決定論的 LLM エージェントの意思決定プロセスを理解し、次のことが可能になります。
- エージェントを観察する : エージェントが各ステップで何を知っていて何をしたかを正確に分析する
- 間違いをデバッグする : 会話の流れの中でエラーが発生した場所と理由を特定する
- 代替案の検討 : チェックポイントからさまざまな会話パスを再生してテストします
このページでは、Mosaic AI エージェント フレームワークと LangGraph、そして Lakebase をメモリ ストアとして使用してステートフル エージェントを作成する方法を説明します。

要件
ステートフル エージェントを作成するには、次のものが必要です。
- Lakebase インスタンスのセットアップについては、 「データベース インスタンスの作成と管理」を参照してください。
サンプルノートブック
次のノートブックでは、このページの概念を使用して、Lakebase を使用したステートフル エージェントを実装します。
スレッドスコープメモリを備えたステートフルエージェント
LangGraph タイムトラベルを実装する
LangGraph タイムトラベルを使用して、チェックポイントから実行を再開します。会話を再生するか、会話を変更して別のパスを検討することができます。チェックポイントから再開するたびに、LangGraph は会話履歴に新しいフォークを作成し、元の会話を保持しながら実験を可能にします。
- 
エージェント コードで、 LangGraphResponsesAgentクラス内のチェックポイント履歴を取得し、チェックポイントの状態を更新する関数を作成します。Pythonfrom typing import List, Dict
 def get_checkpoint_history(self, thread_id: str, limit: int = 10) -> List[Dict[str, Any]]:
 """Retrieve checkpoint history for a thread.
 Args:
 thread_id: The thread identifier
 limit: Maximum number of checkpoints to return
 Returns:
 List of checkpoint information including checkpoint_id, timestamp, and next nodes
 """
 config = {"configurable": {"thread_id": thread_id}}
 with self.get_connection() as conn:
 checkpointer = PostgresSaver(conn)
 graph = self._create_graph(checkpointer)
 history = []
 for state in graph.get_state_history(config):
 if len(history) >= limit:
 break
 history.append({
 "checkpoint_id": state.config["configurable"]["checkpoint_id"],
 "thread_id": thread_id,
 "timestamp": state.created_at,
 "next_nodes": state.next,
 "message_count": len(state.values.get("messages", [])),
 # Include last message summary for context
 "last_message": self._get_last_message_summary(state.values.get("messages", []))
 })
 return history
 def _get_last_message_summary(self, messages: List[Any]) -> Optional[str]:
 """Get a snippet of the last message for checkpoint identification"""
 return getattr(messages[-1], "content", "")[:100] if messages else None
 def update_checkpoint_state(self, thread_id: str, checkpoint_id: str,
 new_messages: Optional[List[Dict]] = None) -> Dict[str, Any]:
 """Update state at a specific checkpoint (used for modifying conversation history).
 Args:
 thread_id: The thread identifier
 checkpoint_id: The checkpoint to update
 new_messages: Optional new messages to set at this checkpoint
 Returns:
 New checkpoint configuration including the new checkpoint_id
 """
 config = {
 "configurable": {
 "thread_id": thread_id,
 "checkpoint_id": checkpoint_id
 }
 }
 with self.get_connection() as conn:
 checkpointer = PostgresSaver(conn)
 graph = self._create_graph(checkpointer)
 # Prepare the values to update
 values = {}
 if new_messages:
 cc_msgs = self.prep_msgs_for_cc_llm(new_messages)
 values["messages"] = cc_msgs
 # Update the state (creates a new checkpoint)
 new_config = graph.update_state(config, values=values)
 return {
 "thread_id": thread_id,
 "checkpoint_id": new_config["configurable"]["checkpoint_id"],
 "parent_checkpoint_id": checkpoint_id
 }
- 
チェックポイントの受け渡しをサポートするために、 predictとpredict_stream関数を更新します。
- Predict
- Predict_stream
def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
    """Non-streaming prediction"""
    # The same thread_id is used by BOTH predict() and predict_stream()
    ci = dict(request.custom_inputs or {})
    if "thread_id" not in ci:
        ci["thread_id"] = str(uuid.uuid4())
    request.custom_inputs = ci
    outputs = [
        event.item
        for event in self.predict_stream(request)
        if event.type == "response.output_item.done"
    ]
    # Include thread_id and checkpoint_id in custom outputs
    custom_outputs = {
        "thread_id": ci["thread_id"]
    }
    if "checkpoint_id" in ci:
        custom_outputs["parent_checkpoint_id"] = ci["checkpoint_id"]
    try:
        history = self.get_checkpoint_history(ci["thread_id"], limit=1)
        if history:
            custom_outputs["checkpoint_id"] = history[0]["checkpoint_id"]
    except Exception as e:
        logger.warning(f"Could not retrieve new checkpoint_id: {e}")
    return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs)
def predict_stream(
    self,
    request: ResponsesAgentRequest,
) -> Generator[ResponsesAgentStreamEvent, None, None]:
    """Streaming prediction with PostgreSQL checkpoint branching support.
    Accepts in custom_inputs:
    - thread_id: Conversation thread identifier for session
    - checkpoint_id (optional): Checkpoint to resume from (for branching)
    """
    # Get thread ID and checkpoint ID from custom inputs
    custom_inputs = request.custom_inputs or {}
    thread_id = custom_inputs.get("thread_id", str(uuid.uuid4()))  # generate new thread ID if one is not passed in
    checkpoint_id = custom_inputs.get("checkpoint_id")  # Optional for branching
    # Convert incoming Responses messages to LangChain format
    langchain_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
    # Build checkpoint configuration
    checkpoint_config = {"configurable": {"thread_id": thread_id}}
    # If checkpoint_id is provided, we're branching from that checkpoint
    if checkpoint_id:
        checkpoint_config["configurable"]["checkpoint_id"] = checkpoint_id
        logger.info(f"Branching from checkpoint: {checkpoint_id} in thread: {thread_id}")
    # DATABASE CONNECTION POOLING LOGIC FOLLOWS
    # Use connection from pool
次に、チェックポイント分岐をテストします。
- 
会話スレッドを開始し、いくつかのメッセージを追加します。 Pythonfrom agent import AGENT
 # Initial conversation - starts a new thread
 response1 = AGENT.predict({
 "input": [{"role": "user", "content": "I'm planning for an upcoming trip!"}],
 })
 print(response1.model_dump(exclude_none=True))
 thread_id = response1.custom_outputs["thread_id"]
 # Within the same thread, ask a follow-up question - thread-scoped memory will remember previous messages in the same thread/conversation session
 response2 = AGENT.predict({
 "input": [{"role": "user", "content": "I'm headed to SF!"}],
 "custom_inputs": {"thread_id": thread_id}
 })
 print(response2.model_dump(exclude_none=True))
 # Within the same thread, ask a follow-up question - thread-scoped memory will remember previous messages in the same thread/conversation session
 response3 = AGENT.predict({
 "input": [{"role": "user", "content": "Where did I say I'm going?"}],
 "custom_inputs": {"thread_id": thread_id}
 })
 print(response3.model_dump(exclude_none=True))
- 
チェックポイント履歴を取得し、別のメッセージで会話をフォークします。 Python# Get checkpoint history to find branching point
 history = AGENT.get_checkpoint_history(thread_id, 20)
 # Retrieve checkpoint at index - indices count backward from most recent checkpoint
 index = max(1, len(history) - 4)
 branch_checkpoint = history[index]["checkpoint_id"]
 # Branch from node with next_node = `('__start__',)` to re-input message to agent at certain part of conversation
 # I want to update the information of which city I am going to
 # Within the same thread, branch from a checkpoint and override it with different context to continue the conversation in a new fork
 response4 = AGENT.predict({
 "input": [{"role": "user", "content": "I'm headed to New York!"}],
 "custom_inputs": {
 "thread_id": thread_id,
 "checkpoint_id": branch_checkpoint # Branch from this checkpoint!
 }
 })
 print(response4.model_dump(exclude_none=True))
 # Thread ID stays the same even though it branched from a checkpoint:
 branched_thread_id = response4.custom_outputs["thread_id"]
 print(f"original thread id was {thread_id}")
 print(f"new thread id after branching is the same as original: {branched_thread_id}")
 # Continue the conversation in the same thread and it will pick up from the information you tell it in your branch
 response5 = AGENT.predict({
 "input": [{"role": "user", "content": "Where am I going?"}],
 "custom_inputs": {
 "thread_id": thread_id,
 }
 })
 print(response5.model_dump(exclude_none=True))
デプロイされたステートフルエージェントをクエリする
エージェントをモデルサービングエンドポイントにデプロイした後、クエリの手順については、デプロイされた Mosaic AI エージェントのクエリを参照してください。
スレッド ID を渡すには、 extra_bodyパラメータを使用します。次の例は、スレッド ID をResponsesAgentエンドポイントに渡す方法を示しています。
   response1 = client.responses.create(
    model=endpoint,
    input=[{"role": "user", "content": "What are stateful agents?"}],
    extra_body={
        "custom_inputs": {"thread_id": thread_id}
    }
)