# Source code for langchain_experimental.tot.base

from __future__ import annotations

from textwrap import indent
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)

from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import (
    BaseThoughtGenerationStrategy,
    ProposePromptStrategy,
)


class ToTChain(Chain):
    """
    Chain implementing the Tree of Thought (ToT).
    """

    llm: BaseLanguageModel
    """
    Language model to use. It must be set to produce different variations for
    the same prompt.
    """
    checker: ToTChecker
    """ToT Checker to use."""
    output_key: str = "response"  #: :meta private:
    k: int = 10
    """The maximum number of conversation rounds"""
    c: int = 3
    """The number of children to explore at each node"""
    tot_memory: ToTDFSMemory = ToTDFSMemory()
    tot_controller: ToTController = ToTController()
    tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
    verbose_llm: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
        """
        Create a ToTChain from a language model.

        :param llm: The language model to use.
        :param kwargs: Additional arguments to pass to the ToTChain constructor.
        """
        return cls(llm=llm, **kwargs)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Keep the controller's per-node child limit in sync with this chain's `c`.
        self.tot_controller.c = self.c

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return ["problem_description"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def log_thought(
        self,
        thought: Thought,
        level: int,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> None:
        """Report a thought through the callback manager, if one is given.

        The text is indented according to ``level`` and colored by the
        thought's validity (final: green, intermediate: yellow, invalid: red).

        :param thought: The thought to report.
        :param level: Depth used to indent the reported text.
        :param run_manager: Callback manager to report to; nothing is emitted
            when it is ``None``.
        """
        if run_manager:
            colors = {
                ThoughtValidity.VALID_FINAL: "green",
                ThoughtValidity.VALID_INTERMEDIATE: "yellow",
                ThoughtValidity.INVALID: "red",
            }
            text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
            run_manager.on_text(
                text=text, color=colors[thought.validity], verbose=self.verbose
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Search for a solution to ``inputs["problem_description"]``.

        Up to ``k`` thoughts are generated and validated by ``checker``. The
        first thought judged ``VALID_FINAL`` is returned under ``output_key``.
        Otherwise ``"No solution found"`` is returned.

        :param inputs: Must contain the ``"problem_description"`` key.
        :param run_manager: Optional callback manager for progress reporting.
        :return: Mapping of ``output_key`` to the solution text.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        if run_manager:
            run_manager.on_text(text="Starting the ToT solve procedure.\n")

        # Every solve starts from an empty thought tree. The memory lives on
        # the chain, so without this reset a second call would resume from the
        # previous problem's tree. The controller would then extend stale,
        # unrelated thoughts, while `thoughts_path` below restarts at ().
        self.tot_memory = type(self.tot_memory)()

        problem_description = inputs["problem_description"]
        checker_inputs = {"problem_description": problem_description}
        thoughts_path: tuple[str, ...] = ()
        thought_generator = self.tot_strategy_class(
            llm=self.llm, c=self.c, verbose=self.verbose_llm
        )

        for _ in range(self.k):
            level = self.tot_memory.level
            thought_text = thought_generator.next_thought(
                problem_description, thoughts_path, callbacks=_run_manager.get_child()
            )
            # The checker judges the candidate in the context of its full path.
            checker_inputs["thoughts"] = thoughts_path + (thought_text,)
            thought_validity = self.checker(
                checker_inputs, callbacks=_run_manager.get_child()
            )["validity"]
            thought = Thought(text=thought_text, validity=thought_validity)
            if thought.validity == ThoughtValidity.VALID_FINAL:
                self.log_thought(thought, level, run_manager)
                return {self.output_key: thought.text}
            self.tot_memory.store(thought)
            self.log_thought(thought, level, run_manager)
            # The controller derives, from the memory, the thought path that
            # the next generated thought should extend.
            thoughts_path = self.tot_controller(self.tot_memory)

        return {self.output_key: "No solution found"}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        raise NotImplementedError("Async not implemented yet")

    @property
    def _chain_type(self) -> str:
        return "tot"