什么是chains
Langchain是一个由多种组件构成的系统,其中每个组件都有其特定的职能。这里的'Chain',如其名所示,就像一条串联这些组件的链条,使我们能够实现更复杂的功能。在这个系统中,组件可能包括但不限于prompt、memory、retrieval和model,甚至可以是chain本身。以Prompt、Memory、Model为例,我们可以组合这三个组件,创建一个应用程序,该程序能接收用户输入,通过大型模型获取反馈,并将反馈记录在Memory组件中,作为下一次输入的参数。
此外,Langchain还内置了各种类型的chains,满足不同的需求。当然,如果它们不能满足你的特定需求,Langchain也支持自定义chain,以便更好地适应你的业务场景。
快速开始
LLMChain 是一个最简单的链,它接受一个提示模板,使用用户输入对其进行格式化,并从 LLM 返回响应。要使用 LLMChain,我们首先要创建一个提示模板,这个模块接受一个水果的输入变量,用于描述对应水果的特征。
from langchain.prompts import PromptTemplate
prompt = PromptTemplate(
input_variables=["fruit"],
template="Please describe the characteristics of the following fruits {fruit}?",
)
接着我们在创建一个Language Model (LLM), 并和prompt组合构造一个最简单的Chain。要使用这个Chain,我们只需调用'run'方法即可。在这个过程中,LLMChain会帮助我们格式化Prompt字符串,然后将其用作输入,调用LLM生成结果。这样,我们能够轻松地获得对应水果特征的描述。
from langchain.llms import OpenAI
from langchain.chains import LLMChain
llm = OpenAI(temperature=0.9)
chain = LLMChain(llm=llm, prompt=prompt)
# Bananas are a curved, yellow fruit with a thick, edible peel. They are a popular snack and are often eaten raw. Bananas are a good source of potassium, fiber, and vitamin B6. They are also low in calories and fat, making them a healthy snack. Bananas are also used in baking, smoothies, and other recipes.
print(chain.run("banana")) # 等价于 chain.run({"fruit": "banana"})
只有在 Chain 的 output_keys 只有一个元素时,才可以使用 run 方法。另外对于只有一个输入变量的情况下,才直接输入字符串,而无需指定输入dict类型,否则需要使用dict类型。
Langchain的chains模块不仅支持使用run方法,还提供了__call__方法。不同于run方法返回的是字符串,当我们调用__call__时,它将返回一个字典,其中包含输入参数。然而,如果我们希望结果中不包括输入参数,可以在调用__call__方法时指定 return_only_outputs=True。这样,在保证功能丰富的同时,我们能更灵活地控制输出的内容。
# {'fruit': 'banana', 'text': 'nnBananas are a tropical fruit that are curved in shape and have a yellow peel when ripe. They are sweet and creamy in taste and have a soft, starchy texture. Bananas are a good source of potassium, fiber, and vitamin B6. They are also low in calories and fat.'}
print(chain("banana"))
# {'text': 'nnBananas are a sweet, yellow fruit with a curved shape and a soft, creamy texture. They have a thick peel that can be easily peeled off to reveal the soft, sweet flesh inside. Bananas are a great source of potassium, fiber, and vitamin B6. They are also low in calories and fat, making them a healthy snack.'}
print(chain("banana", return_only_outputs=True))
我们也可以上面使用的Language Model 换成聊天模型:
from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template="Please describe the characteristics of the following fruits {fruit}?",
input_variables=["fruit"],
)
)
prompt = ChatPromptTemplate.from_messages([human_message_prompt])
llm = ChatOpenAI(temperature=0.1)
chain = LLMChain(llm=llm, prompt=prompt)
"""
Bananas are a popular fruit known for their distinctive shape, vibrant yellow color, and sweet taste. Here are some characteristics of bananas:
1. Shape and Size: Bananas are elongated and curved, resembling a crescent moon. They typically measure around 6 to 8 inches in length, but can vary in size.
2. Color: When ripe, bananas have a bright yellow peel. However, they can also be green when unripe or develop brown spots as they become overripe.
3. Texture: The texture of a banana is smooth and creamy when ripe, making it easy to eat. The flesh is soft and tender, with a slight firmness.
Overall, bananas are a delicious and nutritious fruit that is loved for their convenience, versatility, and sweet taste.
"""
print(chain.run("banana"))
添加缓存
Langchain的chains模块支持使用'Memory'参数,该参数必须是'BaseMemory'类型。一旦添加了这个参数,chain对象便能在多次调用中持久化并保持数据。简单地说,通过引入'Memory'参数使得chain对象从无状态变为有状态,意味着它可以储存和跨时间维护信息。
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferMemory
llm = ChatOpenAI(temperature=0.1)
conversation = ConversationChain(
llm=llm,
memory=ConversationBufferMemory()
)
print(conversation.run("Answer briefly. What are the first 3 colors of a rainbow?"))
# -> The first three colors of a rainbow are red, orange, and yellow.
print(conversation.run("And the next 4?"))
# -> The next four colors of a rainbow are green, blue, indigo, and violet.
源码分析
在本次的源码分析中,我们将专注于探索Langchain中三个核心类:Chain、LLMChain和ConversationChain。每个类都在实现链式处理逻辑中起到关键作用。具体来说,
- Chain是一个基类,提供了公共的接口和一些默认行为。
- LLMChain是Chain的子类,特点是利用语言模型 (Language Model, 缩写为LLM) 进行工作。
- ConversationChain则用于管理和执行基于会话的交互。
请注意,这里的分析仅限于一些关键的同步方法,异步方法等其它内容将在后续文章中介绍。此外,其他类型如ConversationalRetrievalChain、RouterChain等也将在后续文章中进行深入研究。
添加图片注释,不超过 140 字(可选)
Chain
Chain位于langchain.chains目录下的base.py文件中,从文件命名就可以看出这是一个基类,其中包含了一些关键属性和方法。对于属性,包含memory、callbacks、verbose等。而对于方法,既有抽象方法,又有具体实现的方法。抽象方法定义了所有派生Chain类必须遵循的接口,具体实现的方法为所有派生类提供了通用的功能。
总的来说,这个基类为Langchain创建了一个灵活的架构,使得开发者能够通过创建新的Chain子类来快速扩展和自定义功能。
属性
在langchain.chains的base.py文件中定义的Chain类有几个关键属性值。分别为memory、callbacks、callback_manager、verbose、tags和metadata。
总结来说,这些属性让Chain类具有了更多灵活性并能够支持更复杂的操作,包括内存管理、回调处理,以及详细模式的控制等等。
方法
Chain接口设计得易于创建具有以下特性的应用程序:
Chain类主要暴露出两种方法:
抽象方法
- input_keys(self) -> List[str]:此方法应返回链输入中预期的键。
- output_keys(self) -> List[str]:此方法应返回链输出中预期的键。
- _call:此方法执行链。它是一个私有方法,不面向用户。只在 Chain.call 内部调用,该方法是处理回调配置和一些输入/输出处理的用户面向方法。
- _chain_type(self) -> str:返回chain的类型名称。
非抽象方法
辅助函数
- _validate_inputs(self, inputs: Dict[str, Any]) -> None 和 _validate_outputs(self, outputs: Dict[str, Any]) -> None: 这两个方法分别检查所有的输入和输出是否存在,如果有任何缺失的键,则抛出一个错误。
- _run_output_key(self) -> str:为单字符串输出执行链的便利方法。如果链具有多个输出或非字符串输出,则抛出 ValueError。
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""检查传入的 inputs 字典中的所有键是否都存在于 self.input_keys 中。如果有缺失的键,就抛出一个 ValueError 异常,并提示缺少哪些键"""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
"""检查传入的 outputs 字典中的所有键是否都存在于 self.output_keys 中。如果有缺失的键,也会抛出一个 ValueError 异常,并提示缺少哪些键"""
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
非辅助方法
- prep_inputs 和 prep_outputs: 这两个方法负责验证和准备链的输入和输出,包括从内存中添加输入、保存运行信息到内存。
- apply: 该方法在列表中的所有输入上调用链。
- call:此方法执行链,并处理一些回调配置和输入/输出处理。
- run:对单字符串输出执行链的便利方法,不同于 Chain.call 的是,它只适用于返回单个字符串输出的链。
prep_inputs
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
prep_inputs 方法被定义为一个公共方法,它负责验证和准备 Chain 类的输入。该方法会检查传入的参数,并根据需要从类内部的运行时存储(memory)获取附加输入。返回一个字典类型的数据,其中包含原始输入以及可能从内部运行时存储加载的附加输入。
处理步骤如下:
总体来说,prep_inputs 方法是一个重要的函数,它负责对输入数据进行预处理和验证,以确保其满足 Chain 需要的格式和条件,为后续的链式操作提供了准备。
prep_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
prep_outputs 方法是一个在 Chain 类中定义的方法,其主要职责是验证和准备链的输出数据。除此之外,它还会将执行过程中的相关信息保存到类内部的运行时存储(内存)。
参数:
- inputs:这应是包含原始输入的字典类型的数据,可能包括从内存加载的附加输入。
- outputs:这是包含初始链输出的字典,需要被进一步处理和验证。
- return_only_outputs:这是一个布尔类型的标志,用来指示是否只返回输出数据。如果设为 False,则输入也将添加到最终的输出中。
返回值:
- 返回一个字典,其中包括所有经过处理和验证的输出,根据 return_only_outputs 参数的设定,可能还包括输入数据。
处理步骤如下:
总结来说,prep_outputs 方法在 Chain 类的执行流程中起到了关键作用。它不仅确保了输出数据满足预设条件和格式,还负责记录执行过程,使得后续的调试和分析更加方便。
call
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Execute the chain."""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
# llm 调用之前的回调
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
# llm调用结束后的回调
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
这段代码定义了 call 方法,使得 Chain 类实例可以像函数一样被调用。该方法是链对象执行流程的主要入口点。
参数:
- inputs: 预期为包含输入数据的字典或单一值(如果Chain类只需要一个参数)。字典的键应与 Chain.input_keys 对应,除非预期的输入将从链的内部存储加载。
- return_only_outputs: 布尔值,指示返回值是否只包括输出。如果设为True,只有由链生成的新键值对会被返回;如果为False,原始输入和由链生成的新键值对都将被返回。默认为False。
- callbacks: 回调函数列表,它们将在此次链运行过程中调用,作为构建时指定的回调函数之外的额外操作。
- tags: 字符串列表,用于对链的执行进行标记。这些标签会传递给所有回调函数,并且只有运行时指定的标签会传播到其他对象的调用中。
- metadata: 可选参数,用于存储与链相关联的额外信息。默认为None。
- include_run_info: 布尔值,决定是否在返回的结果中包含此次运行的相关信息。默认为False。
返回:
- 一个包含预期输出的字典,其键应存在于 Chain.output_keys 中定义的键集合。
操作步骤如下:
总体来说,call 方法是 Chain 类实例执行流程的核心控制点。它负责处理输入,启动链的运行,处理任何可能出现的异常,最后返回经过预处理的结果。这个函数的设计使得我们可以灵活地控制和调整链的执行过程。
__call__方法的流程图如下:
添加图片注释,不超过 140 字(可选)
call 方法的时序图如下:
添加图片注释,不超过 140 字(可选)
LLMChain
LLMChain 是一个专门用于执行语言学习模型(LLM)查询的类,它是在基础的 Chain 类上进行拓展实现的。LLMChain 类提供了一种方式将语言模型与特定的工作流整合起来,使得执行语言模型查询变得更加简单和可控。这个类的设计通常用于构建复杂的自然语言处理应用,如对话系统、问答系统等。
主要属性如下:
- prompt: 该属性存储一个对象,负责根据给定的输入生成相应的提示。
- llm:这是实际执行查询的语言模型对象。
- output_key: 字符串,定义了输出字典中结果的键,默认为"text"。
- output_parser: 这是用来解析模型输出的对象,默认使用 NoOpOutputParser,它只返回最可能的字符串,而不会做任何其他转换或处理。
- return_final_only: 布尔值,决定是否仅返回最终解析结果,默认为True。若为false,则会返回关于生成过程的详细信息。
- llm_kwargs: 这是一个字典,包含传递给语言学习模型的参数。
重要的方法包括:
- _call(): 该方法调用 generate 方法,并从其响应中创建输出。
- generate(): 该方法根据输入生成语言学习模型的结果。
- prep_prompts(): 这个函数负责根据输入准备相应的提示。
- apply(): 此方法使用语言学习模型的 generate 方法产生结果,以实现高效运行。
- create_outputs(): 从模型响应中创建并格式化输出。
类方法 from_string() 接受一个语言学习模型和一个模板字符串作为参数,然后返回一个配置好的 LLMChain 实例。
相比Chain的改动
LLMChain 是从基类 Chain 继承而来的子类,它重写了一些父类的抽象方法,并新增了一些专门用于处理语言学习模型(Language Learning Model,LLM)的方法。
主要变动和新特性如下:
- input_keys 和 output_keys: 这两个属性是 LLMChain 从 Chain 类中继承并重写的。它们定义了链对象期待的输入类型和输出类型,以便正确地进行数据处理和返回结果。
- _call(): 这是从 Chain 类中覆盖的方法,它是执行链对象的核心接口。在 LLMChain 中,该方法被定制化以适应特定的语言学习模型调用方式。
- generate(): 这是 LLMChain 类新增的方法,它负责根据给定的输入数据生成语言学习模型的结果。
- prep_prompts(): 新增的方法,用于根据输入准备相应的提示。
- create_outputs(): 另一个新增方法,负责根据模型的响应创建格式化的输出。
- _parse_generation: 这是一个辅助方法,主要用于解析 generate() 方法产生的结果。
- from_string(): 这是一个类方法,设计为外部调用接口。它接受一个语言学习模型和一个模板字符串作为参数,然后返回一个配置好的 LLMChain 实例。
通过这些改动和新增特性,LLMChain 提供了一种更适合处理语言学习任务的链对象实现方式,以便简化复杂的自然语言处理过程。
input_keys 和 output_keys
@property
def input_keys(self) -> List[str]:
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
- input_keys:这个属性返回self.prompt.input_variables的值
- output_keys:如果return_final_only为真,则只返回output_key 属性;否则,它将返回output_key以及"full_generation"这样的字符串。这意味着,如果设置了return_final_only=true,那么只会返回最终的输出结果,否则会返回完整的生成过程用于调试和分析。
prep_prompts
这个函数的主要目的是根据输入列表创建语言模型的提示列表,同时检查所有输入中的"stop"值是否相同。
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
create_outputs
create_outputs是一个转换函数,其主要目的是将语言学习模型(LLM)生成的结果 (LLMResult) 格式化为期望的输出格式。
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
_call
添加图片注释,不超过 140 字(可选)
_call方法首先调用generate方法,然后在generate方法中,它会调用prep_prompts来准备提示,然后使用llm.generate_prompt生成结果。最后,_call使用create_outputs创建输出,返回输出列表的第一个元素。
ConversationChain
ConversationChain是LLMChain的子类,用于进行对话和从内存中加载上下文。
class ConversationChain(LLMChain):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain import ConversationChain, OpenAI
conversation = ConversationChain(llm=OpenAI())
"""
memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
input_key: str = "input" #: :meta private:
output_key: str = "response" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
memory_keys = values["memory"].memory_variables
input_key = values["input_key"]
if input_key in memory_keys:
raise ValueError(
f"The input key {input_key} was also found in the memory keys "
f"({memory_keys}) - please provide keys that don't overlap."
)
prompt_variables = values["prompt"].input_variables
expected_keys = memory_keys + [input_key]
if set(expected_keys) != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but got {memory_keys} as inputs from "
f"memory, and {input_key} as the normal input key."
)
return values
根据 Pydantic 库文档,root_validator 装饰器定义的方法将在所有字段被验证后但在预和后钩子之前运行。这样就可以进行跨字段验证。根验证器接收一个包含已解析字段值的字典,并返回一个更新的版本。
validate_prompt_input_variables 验证器的职责是确保内存键(memory_keys)和输入键(input_key)与提示变量(prompt_variables)一致。它也确保了输入键不会与内存键重叠。如果存在不匹配或者重叠,它将抛出一个 ValueError,指明错误的情况。
默认的prompt 接收两个键值对参数:history 和 input。
- history 代表了过去状态或行为的记忆,它是由 memory 对象提供的。memory 通常包含一系列先前的事件或交互,用来帮助模型生成响应或执行某些操作。例如,在一个聊天机器人应用中,history 可能包括以前的对话历史,以便生成上下文相关的回答。
- input 是当前要处理的数据或请求。这可能是用户输入的一段文本,也可能是其他类型的数据,取决于实际应用的需求。
通过这种方式,prompt 对象得以将过去的历史信息(history)和当前的输入信息(input)结合起来,生成针对特定上下文的语言模型查询。
DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{history}
Human: {input}
AI:"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=DEFAULT_TEMPLATE)
ConversationChain 设定了各个变量的默认值。主要过程可以分解为几个步骤。首先,在 pre_inputs过程中,memory 对象会被加载,并与 input 变量共同构成 prompt 的输入。然后,这个输入被传递给 prompt 模板,执行 prep_prompts 方法生成 'PromptValue'。
关于键值(key)的使用,我们有以下规则:input_key 的默认值是 'input',它在 input_keys(self) 方法中被使用。另外,input_keys 属性会在 pre_inputs 方法中被调用。output_key 则被用在 output_keys 方法中,用于指定输出字典的 key。