Stateful AI agents

Important

This feature is in Public Preview.

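Stateful AI agents maintain context across interactions by using thread IDs to track conversation threads. Checkpointing lets you save an agent's state at specific points, and time travel lets you replay a conversation from any of those states. This helps you understand how non-deterministic large language model (LLM) agents make decisions, and lets you: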

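  1. Observe the agent: analyze exactly what the agent knew and did at each step
  2. Debug errors: identify where and why errors occurred in the conversation flow
  3. Explore alternatives: replay checkpoints and test different conversation paths from them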
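
To make the thread ID and checkpoint ideas concrete, here is a minimal pure-Python model. It is an illustration only, not LangGraph's implementation: each turn stores a snapshot (a checkpoint) of the conversation under its thread ID, so threads stay isolated and any earlier snapshot can be read back. ThreadMemory, add_turn, and state_at are hypothetical names.

```python
from copy import deepcopy
from typing import Dict, List


class ThreadMemory:
    """Toy thread-scoped memory (hypothetical, not LangGraph).

    Every turn appends a checkpoint, i.e. a full snapshot of the message list,
    under its thread_id.
    """

    def __init__(self) -> None:
        self._checkpoints: Dict[str, List[List[dict]]] = {}

    def add_turn(self, thread_id: str, role: str, content: str) -> int:
        history = self._checkpoints.setdefault(thread_id, [])
        # Start from the latest snapshot of this thread, or an empty conversation.
        messages = deepcopy(history[-1]) if history else []
        messages.append({"role": role, "content": content})
        history.append(messages)
        # Index of the new checkpoint within this thread.
        return len(history) - 1

    def state_at(self, thread_id: str, index: int = -1) -> List[dict]:
        # "Replay": read the snapshot saved at a given checkpoint (latest by default).
        return deepcopy(self._checkpoints.get(thread_id, [[]])[index])


memory = ThreadMemory()
memory.add_turn("trip", "user", "I'm planning for an upcoming trip!")
memory.add_turn("trip", "user", "I'm headed to SF!")
memory.add_turn("other", "user", "Hello")

print(len(memory.state_at("trip")))   # 2
print(len(memory.state_at("other")))  # 1
```

The "other" thread never sees the trip messages, and state_at("trip", 0) returns the conversation as it stood after the first turn.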

This page shows how to create a stateful agent using Mosaic AI Agent Framework and LangGraph, with Lakebase as the memory store.

Figure: Stateful agent

Requirements

To create a stateful agent, you need:

Example notebook

The following notebook applies the concepts on this page to implement a stateful agent that uses Lakebase:

Stateful agent with thread-scoped memory


Implement LangGraph time travel

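Use LangGraph time travel to resume execution from a checkpoint. You can replay a conversation as-is or modify it to explore alternative paths. Each time you resume from a checkpoint, LangGraph creates a new branch in the conversation history, which preserves the original data while making it easy to experiment.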
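
The branching behavior can be sketched with a toy checkpoint tree, assuming each checkpoint records its parent. This only illustrates the idea; it is not how LangGraph or PostgresSaver store checkpoints, and Checkpoint, Thread, append, and history are hypothetical names. Resuming from an older checkpoint adds a new child and leaves the original chain untouched, and the new child becomes the thread's latest checkpoint.

```python
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Checkpoint:
    checkpoint_id: str
    parent_id: Optional[str]
    messages: List[str]


@dataclass
class Thread:
    """Toy checkpoint tree for one thread (hypothetical, not LangGraph storage)."""

    checkpoints: Dict[str, Checkpoint] = field(default_factory=dict)
    latest_id: Optional[str] = None

    def append(self, message: str, from_checkpoint: Optional[str] = None) -> str:
        # Resume from the requested checkpoint, or from the latest one by default.
        parent_id = from_checkpoint or self.latest_id
        base = self.checkpoints[parent_id].messages if parent_id else []
        # base + [message] builds a new list, so existing checkpoints are never mutated.
        new = Checkpoint(str(uuid.uuid4()), parent_id, base + [message])
        self.checkpoints[new.checkpoint_id] = new
        self.latest_id = new.checkpoint_id
        return new.checkpoint_id

    def history(self, checkpoint_id: Optional[str] = None) -> List[Checkpoint]:
        # Walk parent pointers back to the root; the newest checkpoint comes first.
        chain: List[Checkpoint] = []
        current = checkpoint_id or self.latest_id
        while current is not None:
            cp = self.checkpoints[current]
            chain.append(cp)
            current = cp.parent_id
        return chain


thread = Thread()
first = thread.append("I'm planning for an upcoming trip!")
original_tip = thread.append("I'm headed to SF!")
# Branch from the first checkpoint with different context.
branch_tip = thread.append("I'm headed to New York!", from_checkpoint=first)

print(thread.history()[0].messages)
# ["I'm planning for an upcoming trip!", "I'm headed to New York!"]
print(thread.history(original_tip)[0].messages)
# ["I'm planning for an upcoming trip!", "I'm headed to SF!"]
```

Because the branch becomes the latest checkpoint in this model, a later call that omits a checkpoint ID continues from the branch, which is the behavior the branching test later on this page relies on.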

  1. In your agent code, add functions to the LangGraphResponsesAgent class that retrieve the checkpoint history and update checkpoint state:

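    # Methods of the LangGraphResponsesAgent class. They rely on helpers defined
    # elsewhere in the agent: get_connection(), _create_graph(), and prep_msgs_for_cc_llm().
    from typing import Any, Dict, List, Optional

    from langgraph.checkpoint.postgres import PostgresSaver
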
    def get_checkpoint_history(self, thread_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Retrieve checkpoint history for a thread.
    
        Args:
            thread_id: The thread identifier
            limit: Maximum number of checkpoints to return
    
        Returns:
            List of checkpoint information including checkpoint_id, timestamp, and next nodes
        """
        config = {"configurable": {"thread_id": thread_id}}
    
        with self.get_connection() as conn:
            checkpointer = PostgresSaver(conn)
            graph = self._create_graph(checkpointer)
    
            history = []
            for state in graph.get_state_history(config):
                if len(history) >= limit:
                    break
    
                history.append({
                    "checkpoint_id": state.config["configurable"]["checkpoint_id"],
                    "thread_id": thread_id,
                    "timestamp": state.created_at,
                    "next_nodes": state.next,
                    "message_count": len(state.values.get("messages", [])),
                    # Include last message summary for context
                    "last_message": self._get_last_message_summary(state.values.get("messages", []))
                })
    
            return history
    
    def _get_last_message_summary(self, messages: List[Any]) -> Optional[str]:
        """Get a snippet of the last message for checkpoint identification"""
        return getattr(messages[-1], "content", "")[:100] if messages else None
    
    def update_checkpoint_state(self, thread_id: str, checkpoint_id: str,
                                new_messages: Optional[List[Dict]] = None) -> Dict[str, Any]:
        """Update state at a specific checkpoint (used for modifying conversation history).
    
        Args:
            thread_id: The thread identifier
            checkpoint_id: The checkpoint to update
            new_messages: Optional new messages to set at this checkpoint
    
        Returns:
            New checkpoint configuration including the new checkpoint_id
        """
        config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_id": checkpoint_id
            }
        }
    
        with self.get_connection() as conn:
            checkpointer = PostgresSaver(conn)
            graph = self._create_graph(checkpointer)
    
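            # Prepare the values to update. If the messages channel uses a reducer such as
            # add_messages, update_state merges these into the existing messages instead of replacing them.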
            values = {}
            if new_messages:
                cc_msgs = self.prep_msgs_for_cc_llm(new_messages)
                values["messages"] = cc_msgs
    
            # Update the state (creates a new checkpoint)
            new_config = graph.update_state(config, values=values)
    
            return {
                "thread_id": thread_id,
                "checkpoint_id": new_config["configurable"]["checkpoint_id"],
                "parent_checkpoint_id": checkpoint_id
            }
    
  2. Update the predict and predict_stream functions so that they accept a checkpoint:

    predict

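    # Assumes module-level `import uuid` and `logger = logging.getLogger(__name__)`,
    # plus ResponsesAgentRequest/ResponsesAgentResponse from mlflow.types.responses.
    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: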
        """Non-streaming prediction"""
        # The same thread_id is used by BOTH predict() and predict_stream()
        ci = dict(request.custom_inputs or {})
        if "thread_id" not in ci:
            ci["thread_id"] = str(uuid.uuid4())
        request.custom_inputs = ci
    
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
    
        # Include thread_id and checkpoint_id in custom outputs
        custom_outputs = {
            "thread_id": ci["thread_id"]
        }
        if "checkpoint_id" in ci:
            custom_outputs["parent_checkpoint_id"] = ci["checkpoint_id"]
    
        try:
            history = self.get_checkpoint_history(ci["thread_id"], limit=1)
            if history:
                custom_outputs["checkpoint_id"] = history[0]["checkpoint_id"]
        except Exception as e:
            logger.warning(f"Could not retrieve new checkpoint_id: {e}")
    
        return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs)
    

    predict_stream

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Streaming prediction with PostgreSQL checkpoint branching support.
    
        Accepts in custom_inputs:
        - thread_id: Conversation thread identifier for session
        - checkpoint_id (optional): Checkpoint to resume from (for branching)
        """
        # Get thread ID and checkpoint ID from custom inputs
        custom_inputs = request.custom_inputs or {}
        thread_id = custom_inputs.get("thread_id", str(uuid.uuid4()))  # generate new thread ID if one is not passed in
        checkpoint_id = custom_inputs.get("checkpoint_id")  # Optional for branching
    
        # Convert incoming Responses messages to LangChain format
        langchain_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
    
        # Build checkpoint configuration
        checkpoint_config = {"configurable": {"thread_id": thread_id}}
        # If checkpoint_id is provided, we're branching from that checkpoint
        if checkpoint_id:
            checkpoint_config["configurable"]["checkpoint_id"] = checkpoint_id
            logger.info(f"Branching from checkpoint: {checkpoint_id} in thread: {thread_id}")
    
        # DATABASE CONNECTION POOLING LOGIC FOLLOWS
        # Use connection from pool
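
The checkpoint handling in predict_stream comes down to turning custom_inputs into a LangGraph config. The following standalone sketch isolates that step so it can be tested without a database. build_checkpoint_config is a hypothetical helper name, and unlike the code above it also treats an empty thread_id as missing.

```python
import uuid
from typing import Any, Dict, Optional


def build_checkpoint_config(custom_inputs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Mirror the config-building logic in predict_stream (hypothetical standalone helper)."""
    ci = dict(custom_inputs or {})
    # Generate a new thread ID only when the caller did not pass one.
    thread_id = ci.get("thread_id") or str(uuid.uuid4())
    config: Dict[str, Any] = {"configurable": {"thread_id": thread_id}}
    # A checkpoint_id makes the run resume from that checkpoint, so new checkpoints fork from it.
    if ci.get("checkpoint_id"):
        config["configurable"]["checkpoint_id"] = ci["checkpoint_id"]
    return config


print(build_checkpoint_config({"thread_id": "t-1"}))
# {'configurable': {'thread_id': 't-1'}}
print(build_checkpoint_config({"thread_id": "t-1", "checkpoint_id": "cp-7"}))
# {'configurable': {'thread_id': 't-1', 'checkpoint_id': 'cp-7'}}
```

Keeping this logic in one pure function makes it easy to unit-test the branching contract separately from the Postgres connection code.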
    

Then, test checkpoint branching:

  1. Start a conversation thread and add some messages:

    from agent import AGENT
    # Initial conversation - starts a new thread
    response1 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm planning for an upcoming trip!"}],
    })
    print(response1.model_dump(exclude_none=True))
    thread_id = response1.custom_outputs["thread_id"]
    
    # Within the same thread, ask a follow-up question - thread-scoped memory will remember previous messages in the same thread/conversation session
    response2 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm headed to SF!"}],
        "custom_inputs": {"thread_id": thread_id}
    })
    print(response2.model_dump(exclude_none=True))
    
    # Within the same thread, ask a follow-up question - thread-scoped memory will remember previous messages in the same thread/conversation session
    response3 = AGENT.predict({
        "input": [{"role": "user", "content": "Where did I say I'm going?"}],
        "custom_inputs": {"thread_id": thread_id}
    })
    print(response3.model_dump(exclude_none=True))
    
    
  2. Retrieve the checkpoint history and fork the conversation with a different message:

    # Get checkpoint history to find branching point
    history = AGENT.get_checkpoint_history(thread_id, 20)
    # history is ordered newest first, so a larger index selects an older checkpoint
    index = max(1, len(history) - 4)
    branch_checkpoint = history[index]["checkpoint_id"]
    
    # Branch from a checkpoint whose next_nodes is ('__start__',) to re-send a message
    # at that point in the conversation, here changing the destination city.
    # The fork stays in the same thread and continues with the new context.
    response4 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm headed to New York!"}],
        "custom_inputs": {
            "thread_id": thread_id,
            "checkpoint_id": branch_checkpoint # Branch from this checkpoint!
        }
    })
    print(response4.model_dump(exclude_none=True))
    
    # Thread ID stays the same even though it branched from a checkpoint:
    branched_thread_id = response4.custom_outputs["thread_id"]
    print(f"original thread id was {thread_id}")
    print(f"new thread id after branching is the same as original: {branched_thread_id}")
    
    # Continue in the same thread without a checkpoint_id: the agent picks up from the branch you just created
    response5 = AGENT.predict({
        "input": [{"role": "user", "content": "Where am I going?"}],
        "custom_inputs": {
            "thread_id": thread_id,
        }
    })
    print(response5.model_dump(exclude_none=True))
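
get_checkpoint_history returns records newest first, so choosing a branch point with a fixed index such as len(history) - 4 depends on how many checkpoints each turn produces. A sketch of an alternative, assuming (as the comment in the example suggests) that you want to branch from checkpoints whose next_nodes is ('__start__',): filter on that field, then count user turns back from the latest one. find_branch_points, pick_branch_checkpoint, and the sample records are hypothetical.

```python
from typing import Any, Dict, List, Optional


def find_branch_points(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return checkpoint records whose next node is '__start__', keeping newest-first order."""
    return [h for h in history if tuple(h.get("next_nodes") or ()) == ("__start__",)]


def pick_branch_checkpoint(history: List[Dict[str, Any]], turns_back: int) -> Optional[str]:
    """Pick the branch point `turns_back` turns before the latest one, or None if out of range."""
    points = find_branch_points(history)
    if 0 <= turns_back < len(points):
        return points[turns_back]["checkpoint_id"]
    return None


# Hypothetical records shaped like the dicts get_checkpoint_history builds.
sample_history = [
    {"checkpoint_id": "cp-5", "next_nodes": ()},
    {"checkpoint_id": "cp-4", "next_nodes": ("agent",)},
    {"checkpoint_id": "cp-3", "next_nodes": ("__start__",)},
    {"checkpoint_id": "cp-2", "next_nodes": ()},
    {"checkpoint_id": "cp-1", "next_nodes": ("__start__",)},
]
print(pick_branch_checkpoint(sample_history, 1))  # cp-1
```

The returned ID could then be passed as checkpoint_id in custom_inputs, as response4 does above.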
    

Query your deployed stateful agent

After you deploy your agent to a Model Serving endpoint, see Query a deployed Mosaic AI agent for querying instructions.

To pass in a thread ID, use the extra_body parameter. The following example shows how to pass a thread ID to a ResponsesAgent endpoint:

    # client: an OpenAI-compatible client for your workspace; endpoint: the serving endpoint name
    response1 = client.responses.create(
        model=endpoint,
        input=[{"role": "user", "content": "What are stateful agents?"}],
        extra_body={
            "custom_inputs": {"thread_id": thread_id}
        }
    )
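
If you call the endpoint over REST instead of through an OpenAI client, the same custom_inputs travel in the JSON body. The following is a minimal standard-library sketch, assuming the Databricks /serving-endpoints/<name>/invocations URL, a personal access token, and that the endpoint accepts the same request shape as AGENT.predict on this page. build_payload and invoke are hypothetical names. Adding checkpoint_id would branch the deployed conversation, since predict_stream reads it from custom_inputs.

```python
import json
import urllib.request
from typing import Any, Dict, Optional


def build_payload(message: str, thread_id: Optional[str] = None,
                  checkpoint_id: Optional[str] = None) -> Dict[str, Any]:
    """Build a request body in the shape AGENT.predict accepts (hypothetical helper)."""
    custom_inputs: Dict[str, str] = {}
    if thread_id:
        custom_inputs["thread_id"] = thread_id
    if checkpoint_id:
        custom_inputs["checkpoint_id"] = checkpoint_id
    payload: Dict[str, Any] = {"input": [{"role": "user", "content": message}]}
    # Omit custom_inputs entirely when empty so the agent generates a new thread ID.
    if custom_inputs:
        payload["custom_inputs"] = custom_inputs
    return payload


def invoke(host: str, endpoint: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST the payload to a serving endpoint (assumed URL layout; needs network access)."""
    req = urllib.request.Request(
        f"{host.rstrip('/')}/serving-endpoints/{endpoint}/invocations",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))


payload = build_payload("Where am I going?", thread_id="t-1")
print(json.dumps(payload))
# {"input": [{"role": "user", "content": "Where am I going?"}], "custom_inputs": {"thread_id": "t-1"}}
```

Assuming the endpoint returns custom_outputs as AGENT.predict does, you would read thread_id (and checkpoint_id) from each response and send them back on the next call.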

Next steps