"""
CPAL链及其子链
"""
from __future__ import annotations
import json
from typing import Any, ClassVar, Dict, List, Optional, Type
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.prompt import PromptTemplate
from langchain_experimental import pydantic_v1 as pydantic
from langchain_experimental.cpal.constants import Constant
from langchain_experimental.cpal.models import (
CausalModel,
InterventionModel,
NarrativeModel,
QueryModel,
StoryModel,
)
from langchain_experimental.cpal.templates.univariate.causal import (
template as causal_template,
)
from langchain_experimental.cpal.templates.univariate.intervention import (
template as intervention_template,
)
from langchain_experimental.cpal.templates.univariate.narrative import (
template as narrative_template,
)
from langchain_experimental.cpal.templates.univariate.query import (
template as query_template,
)
class _BaseStoryElementChain(Chain):
    """Base class for chains that turn text into one CPAL story element.

    Subclasses set ``pydantic_model`` (the type the LLM completion is parsed
    into) and ``template`` (the prompt sent to the LLM).
    """

    chain: LLMChain
    input_key: str = Constant.narrative_input.value  #: :meta private:
    output_key: str = Constant.chain_answer.value  #: :meta private:
    pydantic_model: ClassVar[
        Optional[Type[pydantic.BaseModel]]
    ] = None  #: :meta private:
    template: ClassVar[Optional[str]] = None  #: :meta private:

    @classmethod
    def parser(cls) -> PydanticOutputParser:
        """Build a parser that turns LLM output into a ``pydantic_model`` object.

        Raises:
            NotImplementedError: if the class does not define ``pydantic_model``.
        """
        if cls.pydantic_model is None:
            raise NotImplementedError(
                f"pydantic_model not implemented for {cls.__name__}"
            )
        return PydanticOutputParser(pydantic_object=cls.pydantic_model)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> Any:
        """Create the chain with an ``LLMChain`` built from the univariate prompt.

        Args:
            llm: language model used by the inner ``LLMChain``.
            **kwargs: forwarded to the constructor. An optional ``template``
                entry overrides ``cls.template``. It is consumed here rather
                than forwarded, because ``template`` is a class variable and
                not a model field.
        """
        template = kwargs.pop("template", cls.template)
        prompt = PromptTemplate(
            input_variables=[Constant.narrative_input.value],
            template=template,
            partial_variables={
                "format_instructions": cls.parser().get_format_instructions()
            },
        )
        return cls(chain=LLMChain(llm=llm, prompt=prompt), **kwargs)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the LLM on ``inputs[self.input_key]`` and parse its completion.

        Returns:
            A dict mapping ``Constant.chain_data.value`` to the parsed pydantic
            object, and ``self.output_key`` to ``None``. Element chains produce
            data, not a final answer.
        """
        # Propagate callbacks so the inner LLM run nests under this chain's run.
        callbacks = run_manager.get_child() if run_manager else None
        completion = self.chain.run(inputs[self.input_key], callbacks=callbacks)
        pydantic_data = type(self).parser().parse(completion)
        return {
            Constant.chain_data.value: pydantic_data,
            # Use the configured key so the result matches ``output_keys``.
            self.output_key: None,
        }
class NarrativeChain(_BaseStoryElementChain):
    """Decompose the narrative into its story elements.

    - causal model
    - query
    - intervention
    """

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
    template: ClassVar[str] = narrative_template
class CausalChain(_BaseStoryElementChain):
    """Translate the causal narrative into a stack of operations."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
    template: ClassVar[str] = causal_template
class InterventionChain(_BaseStoryElementChain):
    """Set the hypothetical conditions for the causal model."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
    template: ClassVar[str] = intervention_template
class QueryChain(_BaseStoryElementChain):
    """Query the outcome table using SQL.

    *Security note*: This class implements an AI technique that generates SQL
    code. If those SQL commands are executed, it is critical to ensure they use
    credentials that are narrowly scoped to only the permissions this chain
    needs. Failure to do so may result in data corruption or loss, since this
    chain may attempt commands like `DROP TABLE` or `INSERT` if appropriately
    prompted. The best way to guard against such negative outcomes is to (as
    appropriate) limit the permissions granted to the credentials used with
    this chain.
    """

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
    template: ClassVar[str] = query_template  # TODO: incl. table schema
class CPALChain(_BaseStoryElementChain):
    """Causal program-aided language (CPAL) chain implementation.

    *Security note*: The building blocks of this class include the
    implementation of an AI technique that generates SQL code. If those SQL
    commands are executed, it is critical to ensure they use credentials that
    are narrowly scoped to only the permissions this chain needs. Failure to do
    so may result in data corruption or loss, since this chain may attempt
    commands like `DROP TABLE` or `INSERT` if appropriately prompted. The best
    way to guard against such negative outcomes is to (as appropriate) limit
    the permissions granted to the credentials used with this chain.
    """

    llm: BaseLanguageModel
    narrative_chain: Optional[NarrativeChain] = None
    causal_chain: Optional[CausalChain] = None
    intervention_chain: Optional[InterventionChain] = None
    query_chain: Optional[QueryChain] = None
    # Story built by the most recent ``_call``; None until the chain has run.
    # TODO: change name ?
    _story: Optional[StoryModel] = pydantic.PrivateAttr(default=None)

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> CPALChain:
        """Instantiate the chain together with its component chains.

        Args:
            llm: language model used by the reporting chain and by every
                component chain built here.
            **kwargs: forwarded to the constructor. Any of ``chain``,
                ``narrative_chain``, ``causal_chain``, ``intervention_chain``
                or ``query_chain`` given here replaces the default built from
                ``llm``.

        *Security note*: The building blocks of this class include the
        implementation of an AI technique that generates SQL code. If those SQL
        commands are executed, it is critical to ensure they use credentials
        that are narrowly scoped to only the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this
        chain may attempt commands like `DROP TABLE` or `INSERT`. The best way
        to guard against such negative outcomes is to (as appropriate) limit
        the permissions granted to the credentials used with this chain.
        """
        # Build defaults only for the pieces the caller did not supply.
        # Previously, passing any of them in kwargs raised
        # "got multiple values for keyword argument".
        if "chain" not in kwargs:
            kwargs["chain"] = LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    input_variables=["question", "query_result"],
                    template=(
                        "Summarize this answer '{query_result}' to this "
                        "question '{question}'? "
                    ),
                ),
            )
        component_chains: Dict[str, Type[_BaseStoryElementChain]] = {
            "narrative_chain": NarrativeChain,
            "causal_chain": CausalChain,
            "intervention_chain": InterventionChain,
            "query_chain": QueryChain,
        }
        for field_name, chain_cls in component_chains.items():
            if field_name not in kwargs:
                kwargs[field_name] = chain_cls.from_univariate_prompt(llm=llm)
        return cls(llm=llm, **kwargs)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the CPAL pipeline on the narrative in ``inputs[self.input_key]``.

        The narrative chain splits the input into plot, hypothetical and
        outcome question. The causal, intervention and query chains turn each
        part into a story element. The elements are combined into a
        ``StoryModel``, which is stored on ``self._story``.

        Returns:
            ``{Constant.chain_data.value: story, self.output_key: answer}``,
            merged with any extra ``kwargs``. ``answer`` is the last column of
            the first row of the query result table, converted to ``float``.

        Raises:
            ValueError: if the query result table is empty, meaning the query
                and the computed outcome are incoherent.
        """
        # Child callbacks make component-chain runs nest under this run.
        callbacks = run_manager.get_child() if run_manager else None
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        # instantiate component chains that were not supplied
        if self.narrative_chain is None:
            self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
        if self.causal_chain is None:
            self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
        if self.intervention_chain is None:
            self.intervention_chain = InterventionChain.from_univariate_prompt(
                llm=self.llm
            )
        if self.query_chain is None:
            self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)

        data_key = Constant.chain_data.value

        # decompose narrative into three causal story elements
        narrative = self.narrative_chain(
            inputs[self.input_key], callbacks=callbacks
        )[data_key]
        story = StoryModel(
            causal_operations=self.causal_chain(
                narrative.story_plot, callbacks=callbacks
            )[data_key],
            intervention=self.intervention_chain(
                narrative.story_hypothetical, callbacks=callbacks
            )[data_key],
            query=self.query_chain(
                narrative.story_outcome_question, callbacks=callbacks
            )[data_key],
        )
        self._story = story

        _run_manager.on_text(
            "story outcome data\n" + story._outcome_table.to_string(),
            color="green",
            end="\n\n",
            verbose=self.verbose,
        )
        _run_manager.on_text(
            "query data\n" + json.dumps(story.query.dict(), indent=4),
            color="blue",
            end="\n\n",
            verbose=self.verbose,
        )

        if story.query._result_table.empty:
            # prevent piping bad data into subsequent chains
            raise ValueError(
                "unanswerable, query and outcome are incoherent\n"
                "\n"
                "outcome:\n"
                f"{story._outcome_table}\n"
                "query:\n"
                f"{story.query.dict()}"
            )
        query_result = float(story.query._result_table.values[0][-1])
        # TODO: when composable chains need it, run the reporting chain
        # ``self.chain`` with ``question=story.query.question`` and
        # ``query_result=query_result``. Then return both the numeric result
        # and the human-readable report.

        return {
            data_key: story,
            self.output_key: query_result,
            **kwargs,
        }

    def draw(self, **kwargs: Any) -> None:
        """Draw the directed acyclic graph built by the most recent run.

        Usage in a jupyter notebook::

            >>> from IPython.display import SVG
            >>> cpal_chain.draw(path="graph.svg")
            >>> SVG('graph.svg')

        Args:
            **kwargs: forwarded to the story graph's ``draw_graphviz``
                (for example ``path``).

        Raises:
            ValueError: if the chain has not been run yet, so no story exists.
        """
        if self._story is None:
            raise ValueError(
                "CPALChain has no story to draw yet; run the chain first."
            )
        self._story._networkx_wrapper.draw_graphviz(**kwargs)