Source code for langchain_experimental.cpal.base

"""
CPAL链及其子链
"""
from __future__ import annotations

import json
from typing import Any, ClassVar, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.prompt import PromptTemplate

from langchain_experimental import pydantic_v1 as pydantic
from langchain_experimental.cpal.constants import Constant
from langchain_experimental.cpal.models import (
    CausalModel,
    InterventionModel,
    NarrativeModel,
    QueryModel,
    StoryModel,
)
from langchain_experimental.cpal.templates.univariate.causal import (
    template as causal_template,
)
from langchain_experimental.cpal.templates.univariate.intervention import (
    template as intervention_template,
)
from langchain_experimental.cpal.templates.univariate.narrative import (
    template as narrative_template,
)
from langchain_experimental.cpal.templates.univariate.query import (
    template as query_template,
)


class _BaseStoryElementChain(Chain):
    """Shared base for chains that turn narrative text into a pydantic object.

    Subclasses provide the ``pydantic_model`` and ``template`` class variables.
    The wrapped ``chain`` is run on the narrative and its completion is parsed
    into an instance of ``pydantic_model``.
    """

    chain: LLMChain
    input_key: str = Constant.narrative_input.value  #: :meta private:
    output_key: str = Constant.chain_answer.value  #: :meta private:
    pydantic_model: ClassVar[
        Optional[Type[pydantic.BaseModel]]
    ] = None  #: :meta private:
    template: ClassVar[Optional[str]] = None  #: :meta private:

    @classmethod
    def parser(cls) -> PydanticOutputParser:
        """Return a parser that turns LLM output into a ``pydantic_model`` object.

        Raises:
            NotImplementedError: If the class does not define ``pydantic_model``.
        """
        if cls.pydantic_model is None:
            raise NotImplementedError(
                f"pydantic_model not implemented for {cls.__name__}"
            )
        return PydanticOutputParser(pydantic_object=cls.pydantic_model)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> Any:
        """Build the chain around an ``LLMChain`` using the univariate prompt.

        Args:
            llm: Language model used by the wrapped ``LLMChain``.
            **kwargs: Forwarded unchanged to the class constructor. A
                ``template`` entry, when present, is used as the prompt
                template instead of the class-level ``template``.

        Returns:
            A new instance of ``cls``.
        """
        prompt = PromptTemplate(
            input_variables=[Constant.narrative_input.value],
            template=kwargs.get("template", cls.template),
            partial_variables={
                "format_instructions": cls.parser().get_format_instructions()
            },
        )
        return cls(chain=LLMChain(llm=llm, prompt=prompt), **kwargs)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the LLM chain on the narrative and parse its completion.

        Args:
            inputs: Mapping that contains the narrative under ``self.input_key``.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            A dict holding the parsed pydantic object under the chain-data key
            and ``None`` under the chain-answer key (and under
            ``self.output_key`` when that has been customised).
        """
        completion = self.chain.run(inputs[self.input_key])
        pydantic_data = self.__class__.parser().parse(completion)
        outputs: Dict[str, Any] = {
            Constant.chain_data.value: pydantic_data,
            Constant.chain_answer.value: None,
        }
        # ``output_keys`` advertises ``self.output_key``, so that key must be
        # present in the result even when it was overridden. ``setdefault``
        # keeps every previously returned key/value untouched.
        outputs.setdefault(self.output_key, None)
        return outputs


class NarrativeChain(_BaseStoryElementChain):
    """Decompose a narrative into its story elements.

    - causal model
    - query
    - intervention
    """

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
    template: ClassVar[str] = narrative_template
class CausalChain(_BaseStoryElementChain):
    """Translate the causal narrative into a sequence of operations."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
    template: ClassVar[str] = causal_template
class InterventionChain(_BaseStoryElementChain):
    """Set the hypothetical conditions for the causal model."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
    template: ClassVar[str] = intervention_template
class QueryChain(_BaseStoryElementChain):
    """Query the outcome table using SQL.

    *Security note*: This class implements an AI technique that generates SQL
        code. If those SQL commands are executed, it is critical to ensure
        they use credentials that are narrowly scoped to only the permissions
        this chain needs. Failure to do so may result in data corruption or
        loss, since this chain may attempt commands like `DROP TABLE` or
        `INSERT` if appropriately prompted. The best way to guard against
        such negative outcomes is to (as appropriate) limit the permissions
        granted to the credentials used with this chain.
    """

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
    template: ClassVar[str] = query_template  # TODO: incl. table schema
class CPALChain(_BaseStoryElementChain):
    """Causal program-aided language (CPAL) chain implementation.

    *Security note*: The building blocks of this class include the
        implementation of an AI technique that generates SQL code. If those
        SQL commands are executed, it is critical to ensure they use
        credentials that are narrowly scoped to only the permissions this
        chain needs. Failure to do so may result in data corruption or loss,
        since this chain may attempt commands like `DROP TABLE` or `INSERT`
        if appropriately prompted. The best way to guard against such
        negative outcomes is to (as appropriate) limit the permissions
        granted to the credentials used with this chain.
    """

    llm: BaseLanguageModel
    narrative_chain: Optional[NarrativeChain] = None
    causal_chain: Optional[CausalChain] = None
    intervention_chain: Optional[InterventionChain] = None
    query_chain: Optional[QueryChain] = None
    # Set by ``_call``; remains ``None`` until the chain has been run.
    # TODO: change name ?
    _story: Optional[StoryModel] = pydantic.PrivateAttr(default=None)

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> CPALChain:
        """Instantiate the chain together with its component chains.

        *Security note*: The building blocks of this class include the
            implementation of an AI technique that generates SQL code. If
            those SQL commands are executed, it is critical to ensure they
            use credentials that are narrowly scoped to only the permissions
            this chain needs. Failure to do so may result in data corruption
            or loss, since this chain may attempt commands like `DROP TABLE`
            or `INSERT`. The best way to guard against such negative outcomes
            is to (as appropriate) limit the permissions granted to the
            credentials used with this chain.

        Args:
            llm: Language model shared by the reporting chain and by every
                component chain.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A ``CPALChain`` with all four component chains already built.
        """
        reporting_prompt = PromptTemplate(
            input_variables=["question", "query_result"],
            template=(
                "Summarize this answer '{query_result}' to this "
                "question '{question}'? "
            ),
        )
        return cls(
            llm=llm,
            chain=LLMChain(llm=llm, prompt=reporting_prompt),
            narrative_chain=NarrativeChain.from_univariate_prompt(llm=llm),
            causal_chain=CausalChain.from_univariate_prompt(llm=llm),
            intervention_chain=InterventionChain.from_univariate_prompt(llm=llm),
            query_chain=QueryChain.from_univariate_prompt(llm=llm),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Decompose the narrative, build the story model and answer the query.

        Args:
            inputs: Mapping that contains the narrative under the
                narrative-input key.
            run_manager: Callback manager used to emit intermediate text; a
                no-op manager is used when omitted.
            **kwargs: Extra entries merged verbatim into the returned dict.

        Returns:
            A dict with the ``StoryModel`` under the chain-data key and the
            numeric query result under ``self.output_key``, plus ``kwargs``.

        Raises:
            ValueError: If the query's result table is empty, i.e. the query
                and the outcome are incoherent.
        """
        # instantiate any component chain that was not supplied up front
        if self.narrative_chain is None:
            self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
        if self.causal_chain is None:
            self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
        if self.intervention_chain is None:
            self.intervention_chain = InterventionChain.from_univariate_prompt(
                llm=self.llm
            )
        if self.query_chain is None:
            self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)

        # decompose narrative into three causal story elements
        data_key = Constant.chain_data.value
        narrative = self.narrative_chain(inputs[Constant.narrative_input.value])[
            data_key
        ]

        story = StoryModel(
            causal_operations=self.causal_chain(narrative.story_plot)[data_key],
            intervention=self.intervention_chain(narrative.story_hypothetical)[
                data_key
            ],
            query=self.query_chain(narrative.story_outcome_question)[data_key],
        )
        # remembered so that ``draw`` can render the graph afterwards
        self._story = story

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(
            "story outcome data" + "\n" + story._outcome_table.to_string(),
            color="green",
            end="\n\n",
            verbose=self.verbose,
        )
        _run_manager.on_text(
            "query data" + "\n" + json.dumps(story.query.dict(), indent=4),
            color="blue",
            end="\n\n",
            verbose=self.verbose,
        )

        if story.query._result_table.empty:
            # prevent piping bad data into subsequent chains
            raise ValueError(
                "unanswerable, query and outcome are incoherent\n"
                "\n"
                "outcome:\n"
                f"{story._outcome_table}\n"
                "query:\n"
                f"{story.query.dict()}"
            )

        # the answer is the last column of the first row of the result table
        query_result = float(story.query._result_table.values[0][-1])

        # TODO: when demanded by composable chains, run the reporting chain
        # (``self.chain``) with ``question=story.query.question`` and
        # ``query_result=query_result`` and return the human-readable report
        # alongside the numeric result.
        return {
            Constant.chain_data.value: story,
            self.output_key: query_result,
            **kwargs,
        }

    def draw(self, **kwargs: Any) -> None:
        """Draw the directed acyclic graph generated by the CPAL chain.

        The graph comes from the story stored by ``_call``, so the chain has
        to be run before calling this method.

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> cpal_chain.draw(path="graph.svg")
            >>> SVG('graph.svg')

        Args:
            **kwargs: Forwarded unchanged to ``draw_graphviz``.
        """
        self._story._networkx_wrapper.draw_graphviz(**kwargs)