Langchain源码剖析 Chains系列(一)

2023年 10月 15日 72.9k 0

什么是chains

Langchain是一个由多种组件构成的系统,其中每个组件都有其特定的职能。这里的'Chain',如其名所示,就像一条串联这些组件的链条,使我们能够实现更复杂的功能。在这个系统中,组件可能包括但不限于prompt、memory、retrieval和model,甚至可以是chain本身。以Prompt、Memory、Model为例,我们可以组合这三个组件,创建一个应用程序,该程序能接收用户输入,通过大型模型获取反馈,并将反馈记录在Memory组件中,作为下一次输入的参数。

此外,Langchain还内置了各种类型的chains,满足不同的需求。当然,如果它们不能满足你的特定需求,Langchain也支持自定义chain,以便更好地适应你的业务场景。

快速开始

LLMChain 是一个最简单的链,它接受一个提示模板,使用用户输入对其进行格式化,并从 LLM 返回响应。要使用 LLMChain,我们首先要创建一个提示模板,这个模块接受一个水果的输入变量,用于描述对应水果的特征。

from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["fruit"],
    template="Please describe the characteristics of the following fruits {fruit}?",
)

接着我们在创建一个Language Model (LLM), 并和prompt组合构造一个最简单的Chain。要使用这个Chain,我们只需调用'run'方法即可。在这个过程中,LLMChain会帮助我们格式化Prompt字符串,然后将其用作输入,调用LLM生成结果。这样,我们能够轻松地获得对应水果特征的描述。

from langchain.llms import OpenAI
from langchain.chains import LLMChain
 
llm = OpenAI(temperature=0.9)
chain = LLMChain(llm=llm, prompt=prompt)
# Bananas are a curved, yellow fruit with a thick, edible peel. They are a popular snack and are often eaten raw. Bananas are a good source of potassium, fiber, and vitamin B6. They are also low in calories and fat, making them a healthy snack. Bananas are also used in baking, smoothies, and other recipes.
print(chain.run("banana"))  # 等价于 chain.run({"fruit": "banana"})

只有在 Chain 的 output_keys 只有一个元素时,才可以使用 run 方法。另外对于只有一个输入变量的情况下,才直接输入字符串,而无需指定输入dict类型,否则需要使用dict类型。

Langchain的chains模块不仅支持使用run方法,还提供了__call__方法。不同于run方法返回的是字符串,当我们调用__call__时,它将返回一个字典,其中包含输入参数。然而,如果我们希望结果中不包括输入参数,可以在调用__call__方法时指定 return_only_outputs=True。这样,在保证功能丰富的同时,我们能更灵活地控制输出的内容。

# {'fruit': 'banana', 'text': 'nnBananas are a tropical fruit that are curved in shape and have a yellow peel when ripe. They are sweet and creamy in taste and have a soft, starchy texture. Bananas are a good source of potassium, fiber, and vitamin B6. They are also low in calories and fat.'}
print(chain("banana"))
# {'text': 'nnBananas are a sweet, yellow fruit with a curved shape and a soft, creamy texture. They have a thick peel that can be easily peeled off to reveal the soft, sweet flesh inside. Bananas are a great source of potassium, fiber, and vitamin B6. They are also low in calories and fat, making them a healthy snack.'}
print(chain("banana", return_only_outputs=True))

我们也可以上面使用的Language Model 换成聊天模型:

from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain

human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template="Please describe the characteristics of the following fruits {fruit}?",
            input_variables=["fruit"],
        )
    )

prompt = ChatPromptTemplate.from_messages([human_message_prompt])

llm = ChatOpenAI(temperature=0.1)
chain = LLMChain(llm=llm, prompt=prompt)
"""
Bananas are a popular fruit known for their distinctive shape, vibrant yellow color, and sweet taste. Here are some characteristics of bananas:

1. Shape and Size: Bananas are elongated and curved, resembling a crescent moon. They typically measure around 6 to 8 inches in length, but can vary in size.

2. Color: When ripe, bananas have a bright yellow peel. However, they can also be green when unripe or develop brown spots as they become overripe.

3. Texture: The texture of a banana is smooth and creamy when ripe, making it easy to eat. The flesh is soft and tender, with a slight firmness.

Overall, bananas are a delicious and nutritious fruit that is loved for their convenience, versatility, and sweet taste.
"""
print(chain.run("banana"))

添加缓存

Langchain的chains模块支持使用'Memory'参数,该参数必须是'BaseMemory'类型。一旦添加了这个参数,chain对象便能在多次调用中持久化并保持数据。简单地说,通过引入'Memory'参数使得chain对象从无状态变为有状态,意味着它可以储存和跨时间维护信息。

from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferMemory


llm = ChatOpenAI(temperature=0.1)


conversation = ConversationChain(
    llm=llm,
    memory=ConversationBufferMemory()
)

print(conversation.run("Answer briefly. What are the first 3 colors of a rainbow?"))
# -> The first three colors of a rainbow are red, orange, and yellow.
print(conversation.run("And the next 4?"))
# -> The next four colors of a rainbow are green, blue, indigo, and violet.

源码分析

在本次的源码分析中,我们将专注于探索Langchain中三个核心类:Chain、LLMChain和ConversationChain。每个类都在实现链式处理逻辑中起到关键作用。具体来说,

  • Chain是一个基类,提供了公共的接口和一些默认行为。
  • LLMChain是Chain的子类,特点是利用语言模型 (Language Model, 缩写为LLM) 进行工作。
  • ConversationChain则用于管理和执行基于会话的交互。

请注意,这里的分析仅限于一些关键的同步方法,异步方法等其它内容将在后续文章中介绍。此外,其他类型如ConversationalRetrievalChain、RouterChain等也将在后续文章中进行深入研究。

image.png

添加图片注释,不超过 140 字(可选)

Chain

Chain位于langchain.chains目录下的base.py文件中,从文件命名就可以看出这是一个基类,其中包含了一些关键属性和方法。对于属性,包含memory、callbacks、verbose等。而对于方法,既有抽象方法,又有具体实现的方法。抽象方法定义了所有派生Chain类必须遵循的接口,具体实现的方法为所有派生类提供了通用的功能。

总的来说,这个基类为Langchain创建了一个灵活的架构,使得开发者能够通过创建新的Chain子类来快速扩展和自定义功能。

属性

在langchain.chains的base.py文件中定义的Chain类有几个关键属性值。分别为memory、callbacks、callback_manager、verbose、tags和metadata。

  • memory: 这是一个可选的BaseMemory对象,默认为None。Memory类在每个链条开始和结束时被调用。开始时,Memory加载变量并在链中传递它们。最后,它保存任何返回的变量。
  • callbacks: 这是一个可选的回调处理器列表(或回调管理器),默认为None。在链条的生命周期中,会始终调用回调处理器,从'on_chain_start'开始, 到'on_chain_end' 或 'on_chain_error'结束。
  • callback_manager: 这是一个已弃用的属性,建议使用callbacks替代。
  • verbose: 决定是否在详细模式下运行。在详细模式下,一些中间日志将打印到控制台。默认值是langchain.verbose。
  • tags: 这是与链条相关联的可选标签列表,默认为无。这些标签会与每次对此链条的调用相关联,并作为参数传递给callbacks中定义的处理器。可以用这些标签来识别特定的链条实例及其用例。
  • metadata: 这是与链条相关联的可选元数据,默认为无。这些元数据会与每次对此链条的调用相关联,并作为参数传递给callbacks中定义的处理器。可以用这些元数据来识别特定的链条实例及其用例。
  • 总结来说,这些属性让Chain类具有了更多灵活性并能够支持更复杂的操作,包括内存管理、回调处理,以及详细模式的控制等等。

    方法

    Chain接口设计得易于创建具有以下特性的应用程序:

  • 有状态 (Stateful): 向任何Chain添加Memory,使它具有状态。这意味着Chain可以在多次运行之间记住并持久化其数据。
  • 可观察(Observable): 将Callbacks传递给Chain,以执行额外的功能,例如在组件调用的主流程之外进行日志记录。这让你可以更好地监控和管理链条的运行过程。
  • 可组合(Composable): Chain API足够灵活,可以容易地将Chains与其他组件(包括其他Chains)结合在一起。这样,你可以自由地设计和构建复杂的流程。
  • Chain类主要暴露出两种方法:

  • call: Chains是可调用的。__call__方法是执行Chain的主要方式。该方法接收一个字典作为输入,并返回一个字典作为输出。这种方式让你能对输入和输出进行精细的控制。
  • run: 这是一个便捷方法,它将输入作为参数接收,返回字符串作为输出。这个方法只能被部分Chain使用,并且其返回结果不如__call__丰富。
  • 抽象方法

    • input_keys(self) -> List[str]:此方法应返回链输入中预期的键。
    • output_keys(self) -> List[str]:此方法应返回链输出中预期的键。
    • _call:此方法执行链。它是一个私有方法,不面向用户。只在 Chain.call 内部调用,该方法是处理回调配置和一些输入/输出处理的用户面向方法。
    • _chain_type(self) -> str:返回chain的类型名称。

    非抽象方法

    辅助函数
    • _validate_inputs(self, inputs: Dict[str, Any]) -> None 和 _validate_outputs(self, outputs: Dict[str, Any]) -> None: 这两个方法分别检查所有的输入和输出是否存在,如果有任何缺失的键,则抛出一个错误。
    • _run_output_key(self) -> str:为单字符串输出执行链的便利方法。如果链具有多个输出或非字符串输出,则抛出 ValueError。
        def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
            """检查传入的 inputs 字典中的所有键是否都存在于 self.input_keys 中。如果有缺失的键,就抛出一个 ValueError 异常,并提示缺少哪些键"""
            missing_keys = set(self.input_keys).difference(inputs)
            if missing_keys:
                raise ValueError(f"Missing some input keys: {missing_keys}")
    
        def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
            """检查传入的 outputs 字典中的所有键是否都存在于 self.output_keys 中。如果有缺失的键,也会抛出一个 ValueError 异常,并提示缺少哪些键"""
            missing_keys = set(self.output_keys).difference(outputs)
            if missing_keys:
                raise ValueError(f"Missing some output keys: {missing_keys}")
    
    非辅助方法
    • prep_inputs 和 prep_outputs: 这两个方法负责验证和准备链的输入和输出,包括从内存中添加输入、保存运行信息到内存。
    • apply: 该方法在列表中的所有输入上调用链。
    • call:此方法执行链,并处理一些回调配置和输入/输出处理。
    • run:对单字符串输出执行链的便利方法,不同于 Chain.call 的是,它只适用于返回单个字符串输出的链。
    prep_inputs
        def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
            if not isinstance(inputs, dict):
                _input_keys = set(self.input_keys)
                if self.memory is not None:
                    # If there are multiple input keys, but some get set by memory so that
                    # only one is not set, we can still figure out which key it is.
                    _input_keys = _input_keys.difference(self.memory.memory_variables)
                if len(_input_keys) != 1:
                    raise ValueError(
                        f"A single string input was passed in, but this chain expects "
                        f"multiple inputs ({_input_keys}). When a chain expects "
                        f"multiple inputs, please call it by passing in a dictionary, "
                        "eg `chain({'foo': 1, 'bar': 2})`"
                    )
                inputs = {list(_input_keys)[0]: inputs}
            if self.memory is not None:
                external_context = self.memory.load_memory_variables(inputs)
                inputs = dict(inputs, **external_context)
            self._validate_inputs(inputs)
            return inputs
    

    prep_inputs 方法被定义为一个公共方法,它负责验证和准备 Chain 类的输入。该方法会检查传入的参数,并根据需要从类内部的运行时存储(memory)获取附加输入。返回一个字典类型的数据,其中包含原始输入以及可能从内部运行时存储加载的附加输入。

    处理步骤如下:

  • 如果 inputs 不是字典类型,该方法会假定 Chain 只接受一个输入参数。在此情况下,会检查是否有来自内存的输入,若有则会更新 _input_keys。然后,如果 _input_keys 的长度大于1,则抛出 ValueError 异常。这是因为在这种情况下,Chain 需要多个输入参数,但仅收到了一个输入参数。
  • 接下来,将唯一的输入数据转化为字典类型的结构。
  • 如果 Chain 类具有内存功能,那么它会从内存中加载输入数据并添加到输入字典中,以确保所有必要的输入都已准备就绪。
  • 然后,调用 _validate_inputs 方法来校验输入参数是否完整。这个方法会检查所有预期的键是否存在于输入字典中,如果有任何缺失的键,则抛出一个错误。
  • 最后,返回处理后的输入字典。
  • 总体来说,prep_inputs 方法是一个重要的函数,它负责对输入数据进行预处理和验证,以确保其满足 Chain 需要的格式和条件,为后续的链式操作提供了准备。

    prep_outputs
        def prep_outputs(
            self,
            inputs: Dict[str, str],
            outputs: Dict[str, str],
            return_only_outputs: bool = False,
        ) -> Dict[str, str]:
            self._validate_outputs(outputs)
            if self.memory is not None:
                self.memory.save_context(inputs, outputs)
            if return_only_outputs:
                return outputs
            else:
                return {**inputs, **outputs}
    

    prep_outputs 方法是一个在 Chain 类中定义的方法,其主要职责是验证和准备链的输出数据。除此之外,它还会将执行过程中的相关信息保存到类内部的运行时存储(内存)。

    参数:

    • inputs:这应是包含原始输入的字典类型的数据,可能包括从内存加载的附加输入。
    • outputs:这是包含初始链输出的字典,需要被进一步处理和验证。
    • return_only_outputs:这是一个布尔类型的标志,用来指示是否只返回输出数据。如果设为 False,则输入也将添加到最终的输出中。

    返回值:

    • 返回一个字典,其中包括所有经过处理和验证的输出,根据 return_only_outputs 参数的设定,可能还包括输入数据。

    处理步骤如下:

  • 首先,调用 _validate_outputs 方法对 outputs 参数进行校验,确保所有的输出数据都是完整和有效的。
  • 如果 Chain 类具有内存功能,那么会把 inputs 和 outputs 的上下文信息保存到内存中。这样可以让我们在后续的执行或调试中回溯这次的运行过程。
  • 根据 return_only_outputs 参数决定返回值:如果为 True,那么只返回 outputs;否则,将输入和输出合并后一起返回。
  • 总结来说,prep_outputs 方法在 Chain 类的执行流程中起到了关键作用。它不仅确保了输出数据满足预设条件和格式,还负责记录执行过程,使得后续的调试和分析更加方便。

    call
        def __call__(
            self,
            inputs: Union[Dict[str, Any], Any],
            return_only_outputs: bool = False,
            callbacks: Callbacks = None,
            *,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            include_run_info: bool = False,
        ) -> Dict[str, Any]:
            """Execute the chain."""
            inputs = self.prep_inputs(inputs)
            callback_manager = CallbackManager.configure(
                callbacks,
                self.callbacks,
                self.verbose,
                tags,
                self.tags,
                metadata,
                self.metadata,
            )
            new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
            # llm 调用之前的回调
            run_manager = callback_manager.on_chain_start(
                dumpd(self),
                inputs,
            )
            try:
                outputs = (
                    self._call(inputs, run_manager=run_manager)
                    if new_arg_supported
                    else self._call(inputs)
                )
            except (KeyboardInterrupt, Exception) as e:
                run_manager.on_chain_error(e)
                raise e
            # llm调用结束后的回调
            run_manager.on_chain_end(outputs)
            final_outputs: Dict[str, Any] = self.prep_outputs(
                inputs, outputs, return_only_outputs
            )
            if include_run_info:
                final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
            return final_outputs
    

    这段代码定义了 call 方法,使得 Chain 类实例可以像函数一样被调用。该方法是链对象执行流程的主要入口点。

    参数:

    • inputs: 预期为包含输入数据的字典或单一值(如果Chain类只需要一个参数)。字典的键应与 Chain.input_keys 对应,除非预期的输入将从链的内部存储加载。
    • return_only_outputs: 布尔值,指示返回值是否只包括输出。如果设为True,只有由链生成的新键值对会被返回;如果为False,原始输入和由链生成的新键值对都将被返回。默认为False。
    • callbacks: 回调函数列表,它们将在此次链运行过程中调用,作为构建时指定的回调函数之外的额外操作。
    • tags: 字符串列表,用于对链的执行进行标记。这些标签会传递给所有回调函数,并且只有运行时指定的标签会传播到其他对象的调用中。
    • metadata: 可选参数,用于存储与链相关联的额外信息。默认为None。
    • include_run_info: 布尔值,决定是否在返回的结果中包含此次运行的相关信息。默认为False。

    返回:

    • 一个包含预期输出的字典,其键应存在于 Chain.output_keys 中定义的键集合。

    操作步骤如下:

  • 将输入数据进行预处理,以确保它们满足链的实际运行需求。
  • 配置回调管理器,以便在链的运行过程中执行特定的操作。
  • 使用已配置的回调管理器启动链的运行。
  • 尝试执行 _call 方法(这是在各个子类中根据具体需求定义的方法),以实现链的实际操作。如果在运行过程中出现任何异常,将由回调管理器捕获并处理。
  • 在链运行结束后,通过回调管理器处理关闭事件,例如清理资源,记录日志等。
  • 对链运行生成的输出进行预处理,以满足返回格式的要求。
  • 如果 include_run_info 设为True,则在最终的输出字典中添加此次运行的相关信息。
  • 返回处理后的输出字典。
  • 总体来说,call 方法是 Chain 类实例执行流程的核心控制点。它负责处理输入,启动链的运行,处理任何可能出现的异常,最后返回经过预处理的结果。这个函数的设计使得我们可以灵活地控制和调整链的执行过程。

    __call__方法的流程图如下:

    添加图片注释,不超过 140 字(可选)

    call 方法的时序图如下:

    添加图片注释,不超过 140 字(可选)

    LLMChain

    LLMChain 是一个专门用于执行语言学习模型(LLM)查询的类,它是在基础的 Chain 类上进行拓展实现的。LLMChain 类提供了一种方式将语言模型与特定的工作流整合起来,使得执行语言模型查询变得更加简单和可控。这个类的设计通常用于构建复杂的自然语言处理应用,如对话系统、问答系统等。

    主要属性如下:

    • prompt: 该属性存储一个对象,负责根据给定的输入生成相应的提示。
    • llm:这是实际执行查询的语言模型对象。
    • output_key: 字符串,定义了输出字典中结果的键,默认为"text"。
    • output_parser: 这是用来解析模型输出的对象,默认使用 NoOpOutputParser,它只返回最可能的字符串,而不会做任何其他转换或处理。
    • return_final_only: 布尔值,决定是否仅返回最终解析结果,默认为True。若为false,则会返回关于生成过程的详细信息。
    • llm_kwargs: 这是一个字典,包含传递给语言学习模型的参数。

    重要的方法包括:

    • _call(): 该方法调用 generate 方法,并从其响应中创建输出。
    • generate(): 该方法根据输入生成语言学习模型的结果。
    • prep_prompts(): 这个函数负责根据输入准备相应的提示。
    • apply(): 此方法使用语言学习模型的 generate 方法产生结果,以实现高效运行。
    • create_outputs(): 从模型响应中创建并格式化输出。

    类方法 from_string() 接受一个语言学习模型和一个模板字符串作为参数,然后返回一个配置好的 LLMChain 实例。

    相比Chain的改动

    LLMChain 是从基类 Chain 继承而来的子类,它重写了一些父类的抽象方法,并新增了一些专门用于处理语言学习模型(Language Learning Model,LLM)的方法。

    主要变动和新特性如下:

    • input_keys 和 output_keys: 这两个属性是 LLMChain 从 Chain 类中继承并重写的。它们定义了链对象期待的输入类型和输出类型,以便正确地进行数据处理和返回结果。
    • _call(): 这是从 Chain 类中覆盖的方法,它是执行链对象的核心接口。在 LLMChain 中,该方法被定制化以适应特定的语言学习模型调用方式。
    • generate(): 这是 LLMChain 类新增的方法,它负责根据给定的输入数据生成语言学习模型的结果。
    • prep_prompts(): 新增的方法,用于根据输入准备相应的提示。
    • create_outputs(): 另一个新增方法,负责根据模型的响应创建格式化的输出。
    • _parse_generation: 这是一个辅助方法,主要用于解析 generate() 方法产生的结果。
    • from_string(): 这是一个类方法,设计为外部调用接口。它接受一个语言学习模型和一个模板字符串作为参数,然后返回一个配置好的 LLMChain 实例。

    通过这些改动和新增特性,LLMChain 提供了一种更适合处理语言学习任务的链对象实现方式,以便简化复杂的自然语言处理过程。

    input_keys 和 output_keys

        @property
        def input_keys(self) -> List[str]:
            return self.prompt.input_variables
    
        @property
        def output_keys(self) -> List[str]:
            if self.return_final_only:
                return [self.output_key]
            else:
                return [self.output_key, "full_generation"]
    
    • input_keys:这个属性返回self.prompt.input_variables的值
    • output_keys:如果return_final_only为真,则只返回output_key 属性;否则,它将返回output_key以及"full_generation"这样的字符串。这意味着,如果设置了return_final_only=true,那么只会返回最终的输出结果,否则会返回完整的生成过程用于调试和分析。

    prep_prompts

    这个函数的主要目的是根据输入列表创建语言模型的提示列表,同时检查所有输入中的"stop"值是否相同。

    def prep_prompts(
            self,
            input_list: List[Dict[str, Any]],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Tuple[List[PromptValue], Optional[List[str]]]:
            """Prepare prompts from inputs."""
            stop = None
            if "stop" in input_list[0]:
                stop = input_list[0]["stop"]
            prompts = []
            for inputs in input_list:
                selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
                prompt = self.prompt.format_prompt(**selected_inputs)
                _colored_text = get_colored_text(prompt.to_string(), "green")
                _text = "Prompt after formatting:n" + _colored_text
                if run_manager:
                    run_manager.on_text(_text, end="n", verbose=self.verbose)
                if "stop" in inputs and inputs["stop"] != stop:
                    raise ValueError(
                        "If `stop` is present in any inputs, should be present in all."
                    )
                prompts.append(prompt)
            return prompts, stop
    

    create_outputs

    create_outputs是一个转换函数,其主要目的是将语言学习模型(LLM)生成的结果 (LLMResult) 格式化为期望的输出格式。

        def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
            """Create outputs from response."""
            result = [
                # Get the text of the top generated string.
                {
                    self.output_key: self.output_parser.parse_result(generation),
                    "full_generation": generation,
                }
                for generation in llm_result.generations
            ]
            if self.return_final_only:
                result = [{self.output_key: r[self.output_key]} for r in result]
            return result
    

    _call

    添加图片注释,不超过 140 字(可选)

    _call方法首先调用generate方法,然后在generate方法中,它会调用prep_prompts来准备提示,然后使用llm.generate_prompt生成结果。最后,_call使用create_outputs创建输出,返回输出列表的第一个元素。

    ConversationChain

    ConversationChain是LLMChain的子类,用于进行对话和从内存中加载上下文。

    class ConversationChain(LLMChain):
        """Chain to have a conversation and load context from memory.
    
        Example:
            .. code-block:: python
    
                from langchain import ConversationChain, OpenAI
    
                conversation = ConversationChain(llm=OpenAI())
        """
    
        memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
        """Default memory store."""
        prompt: BasePromptTemplate = PROMPT
        """Default conversation prompt to use."""
    
        input_key: str = "input"  #: :meta private:
        output_key: str = "response"  #: :meta private:
    
        class Config:
            """Configuration for this pydantic object."""
    
            extra = Extra.forbid
            arbitrary_types_allowed = True
    
        @property
        def input_keys(self) -> List[str]:
            """Use this since so some prompt vars come from history."""
            return [self.input_key]
    
        @root_validator()
        def validate_prompt_input_variables(cls, values: Dict) -> Dict:
            """Validate that prompt input variables are consistent."""
            memory_keys = values["memory"].memory_variables
            input_key = values["input_key"]
            if input_key in memory_keys:
                raise ValueError(
                    f"The input key {input_key} was also found in the memory keys "
                    f"({memory_keys}) - please provide keys that don't overlap."
                )
            prompt_variables = values["prompt"].input_variables
            expected_keys = memory_keys + [input_key]
            if set(expected_keys) != set(prompt_variables):
                raise ValueError(
                    "Got unexpected prompt input variables. The prompt expects "
                    f"{prompt_variables}, but got {memory_keys} as inputs from "
                    f"memory, and {input_key} as the normal input key."
                )
            return values
    

    根据 Pydantic 库文档,root_validator 装饰器定义的方法将在所有字段被验证后但在预和后钩子之前运行。这样就可以进行跨字段验证。根验证器接收一个包含已解析字段值的字典,并返回一个更新的版本。

    validate_prompt_input_variables 验证器的职责是确保内存键(memory_keys)和输入键(input_key)与提示变量(prompt_variables)一致。它也确保了输入键不会与内存键重叠。如果存在不匹配或者重叠,它将抛出一个 ValueError,指明错误的情况。

    默认的prompt 接收两个键值对参数:history 和 input。

    • history 代表了过去状态或行为的记忆,它是由 memory 对象提供的。memory 通常包含一系列先前的事件或交互,用来帮助模型生成响应或执行某些操作。例如,在一个聊天机器人应用中,history 可能包括以前的对话历史,以便生成上下文相关的回答。
    • input 是当前要处理的数据或请求。这可能是用户输入的一段文本,也可能是其他类型的数据,取决于实际应用的需求。

    通过这种方式,prompt 对象得以将过去的历史信息(history)和当前的输入信息(input)结合起来,生成针对特定上下文的语言模型查询。

    DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
    
    Current conversation:
    {history}
    Human: {input}
    AI:"""
    PROMPT = PromptTemplate(input_variables=["history", "input"], template=DEFAULT_TEMPLATE)
    

    ConversationChain 设定了各个变量的默认值。主要过程可以分解为几个步骤。首先,在 pre_inputs过程中,memory 对象会被加载,并与 input 变量共同构成 prompt 的输入。然后,这个输入被传递给 prompt 模板,执行 prep_prompts 方法生成 'PromptValue'。

    关于键值(key)的使用,我们有以下规则:input_key 的默认值是 'input',它在 input_keys(self) 方法中被使用。另外,input_keys 属性会在 pre_inputs 方法中被调用。output_key 则被用在 output_keys 方法中,用于指定输出字典的 key。

    相关文章

    JavaScript2024新功能:Object.groupBy、正则表达式v标志
    PHP trim 函数对多字节字符的使用和限制
    新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
    使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
    为React 19做准备:WordPress 6.6用户指南
    如何删除WordPress中的所有评论

    发布评论