[LangChain智能体本质论-07]如何在工具函数中注入各种运行时对象?

工具参数有两种来源,一种必需由外部提供,另一种则是由执行引擎根据约定自动注入,比如表示工具运行时的ToolRuntime。参数注入问题应该从两个方面来看待:

  • 注入参数的识别:模型组件在为工具调用生成ToolCall对象时,它的args不应该包含注入参数,所以在为工具生成tool_call_schame时需要具有注入参数识别的能力;
  • 调用时注入:由于ToolCallargs并不包含注入参数,所以在调用工具函数时需要补充这些漏掉的注入参数。

1. 注入参数的识别

在定义工具函数时,参数具有如下两种注入形式:一种是直接注入,另一种被成为“标记注入”。

1.1 直接注入

继承自_DirectlyInjectedToolArg类型的参数会默认为注入参数,这是一个不具有任何成员定义的“标记”类型。表示工具运行时的ToolRuntime是它的继承者。我们知道作为LangChain执行引擎的Pregel具有一个名为Runtime的运行时对象,ToolRuntime可以认为是它的一个视图或者投影。ToolRuntimeRuntime基础增加了三个字段:表示状态的state、表示当前配置的config和表示工具调用ID的tool_call_id

class _DirectlyInjectedToolArg

@dataclass
class ToolRuntime(_DirectlyInjectedToolArg, Generic[ContextT, StateT]):
    state: StateT
    context: ContextT
    config: RunnableConfig
    stream_writer: StreamWriter
    tool_call_id: str | None
    store: BaseStore | None

1.2 标记注入

标注在注入参数上的Annotated除了指定参数类型外,还可以标注一个继承自InjectedToolArg的类型指示具体注入的是什么。InjectedToolCallIdInjectedStateInjectedStoreInjectedToolArg的三个继承类型,分别用于针对工具调用ID、状态和表示长期存储的BaseStore的参数注入。在前面介绍“Agent的状态”时,我们已经演示过如何利用InjectedState向工具函数注入状态。

class InjectedToolArg

class InjectedToolCallId(InjectedToolArg)
class InjectedState(InjectedToolArg):
    def __init__(self, field: str | None = None) -> None:
        self.field = field
class InjectedStore(InjectedToolArg)

除了上面这两种类型的注入,我们在工具方法通过config参数引入的RunnableConfig也会被自动视为注入参数。

2. 调用时注入参数

作为Pregel对象的Agent来说,所有的工具都被封装在一个类型为ToolNode的节点中,当ToolNodebranch:to:tools通道的变化被驱动执行时,它会根据最近返回的AIMessage中的ToolCall列表选择相应的工具并实施调用。换句话说,大部分参数的注入是由ToolNode完成的。

虽然生成Schema时针对注入参数的识别只需要考虑参数类型或者标注类型是否继承自_DirectlyInjectedToolArg或者InjectedToolArg,但是ToolNode则需要根据具体的类型判断最终应该注入怎样的对象。
但是ToolNode只会识别参数类型ToolRuntime和针对InjectedStateInjectedStore的标注,并完成对应的参数注入。基于InjectedToolCallId针对工具调用ID的注入,以及这对RunnableConfig对象额注入是在BaseTool对象中完成的。

如下的程序演示了上述的几种针对工具函数的参数注入形式。在测试工具函数test_tool中,我们定义了五种自动注入的参数,分别表示当前执行配置、工具运行时、工具调用ID、状态和存储。方法利用断言对它们的值进行了验证,最后利用返回的Command对状态成员“tool_result”进行了设置。

from typing import Annotated, Any,Callable, Sequence, override
import builtins
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel,LanguageModelInput
from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, ToolCall
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs.chat_result import ChatResult, ChatGeneration
from langchain_core.tools import BaseTool,InjectedToolCallId
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.prebuilt import InjectedState, InjectedStore
from langchain.agents.middleware import AgentState
from langgraph.types import Command
from langgraph.prebuilt import ToolRuntime
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langchain_core.messages import HumanMessage

def test_tool(
        config: RunnableConfig,
        runtime: ToolRuntime,
        tool_call_id:Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        store: Annotated[BaseStore, InjectedStore])-> Command:  
    """A test tool that validates the inputs and returns a success message."""

    assert config == runtime.config
    assert tool_call_id == "tool_call_001" == runtime.tool_call_id
    assert state["foo"] == "123"
    assert state["bar"] == "456"
    assert store == runtime.store
    assert isinstance(store, InMemoryStore)

    return Command(
        update={
            "messages":[ToolMessage("Test tool executed successfully!", tool_call_id=tool_call_id)],
            "tool_result": "success",
            },
    )

class ExtendedAgentState(AgentState):
    foo:str
    bar:str
    tool_result:str

class ModelSimulator(BaseChatModel):
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        
        if[message for message in messages if isinstance(message, ToolMessage)]:
             generation = ChatGeneration(message=AIMessage(""))
             return ChatResult(generations=[generation], llm_output={})
        
        tool_call: ToolCall = {
                "name": "test_tool",
                "args": {},
                "id": "tool_call_001",
            }            
        generation = ChatGeneration(message=AIMessage(content="", tool_calls=[tool_call]))
        return ChatResult(generations=[generation], llm_output={})

    @property
    def _llm_type(self) -> str:
        return "model-simulator"

    @override
    def bind_tools(
        self,
        tools: Sequence[builtins.dict[str, Any] | type | Callable | BaseTool],
        *,
        tool_choice: str | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        return self

agent = create_agent(
    model= ModelSimulator(),
    tools=[test_tool],
    store=InMemoryStore(),
    state_schema=ExtendedAgentState,
)

result = agent.invoke(input= {
    "messages": [HumanMessage(content="")], 
    "foo": "123", 
    "bar": "456"}) # type: ignore
assert result["tool_result"] == "success"

在调用create_agent函数创建Agent时,我们将model设置为用于模拟模型的ModelSimulator,它利用返回AIMessage中携带的ToolCall让Agent发起对工具test_tool的调用。除了注册test_tool工具外,我们还将state_schema设置成自定义的ExtendedAgentState对象,ExtendedAgentState在基类AgentState基础上添加了三个状态成员(foo、bar和tool_result)。我们还通过store参数设置了一个InMemoryStore对象作为Agent的长期存储。

在调用Agent时,我们指定了foo和bar两个状态成员。并从执行结果中提取并验证“tool_result”状态成员,用以确定test_tool函数确实作为工具被成功调用了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值