# Source code for langchain.model_laboratory
"""Experiment with different models."""
from __future__ import annotations
from typing import List, Optional, Sequence
from langchain_core.language_models.llms import BaseLLM
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.input import get_color_mapping, print_text
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class ModelLaboratory:
    """Feed one input to several chains and print each chain's output for comparison."""

    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
        """Validate and store the chains to experiment with.

        Args:
            chains: The chains to compare. Each element must be a ``Chain``
                with exactly one input key and exactly one output key.
            names: Optional display labels, one per chain, printed by
                ``compare``. When omitted, ``str(chain)`` is used instead.

        Raises:
            ValueError: If an element of ``chains`` is not a ``Chain``, if a
                chain does not have exactly one input key or exactly one output
                key, or if ``names`` is given with a length different from
                ``chains``.
        """
        for candidate in chains:
            self._check_chain(candidate)
        if names is not None and len(names) != len(chains):
            raise ValueError("Length of chains does not match length of names.")
        self.chains = chains
        # Colors are keyed by each chain's position rendered as a string, which
        # is how ``compare`` looks them up.
        self.chain_colors = get_color_mapping(list(map(str, range(len(self.chains)))))
        self.names = names

    @staticmethod
    def _check_chain(chain: Chain) -> None:
        """Raise ``ValueError`` unless ``chain`` is a single-input, single-output Chain."""
        if not isinstance(chain, Chain):
            raise ValueError(
                "ModelLaboratory should now be initialized with Chains. "
                "If you want to initialize with LLMs, use the `from_llms` method "
                "instead (`ModelLaboratory.from_llms(...)`)"
            )
        if len(chain.input_keys) != 1:
            raise ValueError(
                "Currently only support chains with one input variable, "
                f"got {chain.input_keys}"
            )
        if len(chain.output_keys) != 1:
            raise ValueError(
                "Currently only support chains with one output variable, "
                f"got {chain.output_keys}"
            )

    @classmethod
    def from_llms(
        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
    ) -> ModelLaboratory:
        """Build a laboratory that wraps every LLM in its own ``LLMChain``.

        Args:
            llms: The language models to compare.
            prompt: Prompt shared by every generated chain. Defaults to a
                pass-through template ``"{_input}"`` with the single input
                variable ``_input``. A supplied prompt should have only one
                input variable, because ``__init__`` rejects chains whose
                ``input_keys`` are not exactly one.

        Returns:
            A ``ModelLaboratory`` whose chains are labelled with ``str(llm)``.
        """
        shared_prompt = (
            prompt
            if prompt is not None
            else PromptTemplate(input_variables=["_input"], template="{_input}")
        )
        return cls(
            [LLMChain(llm=model, prompt=shared_prompt) for model in llms],
            names=[str(model) for model in llms],
        )

    def compare(self, text: str) -> None:
        """Print ``text``, then each chain's label and output for that text.

        If the laboratory was created with a prompt, ``text`` fills that
        prompt. Otherwise ``text`` is the entire prompt.

        Args:
            text: The input handed to every chain via ``chain.run``.
        """
        print(f"\033[1mInput:\033[0m\n{text}\n")  # noqa: T201
        for position, chain in enumerate(self.chains):
            label = str(chain) if self.names is None else self.names[position]
            print_text(label, end="\n")
            result = chain.run(text)
            print_text(result, color=self.chain_colors[str(position)], end="\n\n")