"""FlyteKit回调处理程序。"""
from __future__ import annotations
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
if TYPE_CHECKING:
import flytekit
from flytekitplugins.deck import renderer
# Module-level logger; FlyteCallbackHandler uses it to warn when optional
# dependencies (textstat, spacy, spacy's en_core_web_sm model) cannot be loaded.
logger = logging.getLogger(__name__)
def import_flytekit() -> Tuple[flytekit, renderer]:
    """Import flytekit and the renderer module of flytekitplugins-deck-standard.

    Returns:
        A ``(flytekit, renderer)`` tuple: the ``flytekit`` module and the
        ``renderer`` attribute of ``flytekitplugins.deck``.

    Both imports go through ``guard_import``; presumably it raises an
    ImportError naming the pip package to install when one is missing
    (behaviour of langchain_core — confirm).
    """
    flytekit_module = guard_import("flytekit")
    # The deck plugin is published on PyPI under a different name than its
    # import path, hence the explicit pip_name.
    deck_module = guard_import(
        "flytekitplugins.deck", pip_name="flytekitplugins-deck-standard"
    )
    return flytekit_module, deck_module.renderer
def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Analyze ``text`` with textstat and/or spacy.

    Args:
        text: The text to analyze.
        nlp: A loaded spacy language pipeline used for the displaCy
            visualizations; that part is skipped when ``None``.
        textstat: The textstat module used to compute readability metrics;
            that part is skipped when ``None``.

    Returns:
        A dict. When ``textstat`` is given it holds the readability metrics
        twice: nested under ``"text_complexity_metrics"`` and flattened at the
        top level. When ``nlp`` is given it holds ``"dependency_tree"`` and
        ``"entities"``, the displaCy renderings serialized as HTML strings.
    """
    result: Dict[str, Any] = {}

    if textstat is not None:
        # Each name is both the textstat function called and the output key;
        # the tuple order fixes the call order and the key order.
        metric_names = (
            "flesch_reading_ease",
            "flesch_kincaid_grade",
            "smog_index",
            "coleman_liau_index",
            "automated_readability_index",
            "dale_chall_readability_score",
            "difficult_words",
            "linsear_write_formula",
            "gunning_fog",
            "fernandez_huerta",
            "szigriszt_pazos",
            "gutierrez_polini",
            "crawford",
            "gulpease_index",
            "osman",
        )
        metrics = {name: getattr(textstat, name)(text) for name in metric_names}
        result["text_complexity_metrics"] = metrics
        result.update(metrics)

    if nlp is not None:
        spacy = import_spacy()
        parsed = nlp(text)
        result["dependency_tree"] = spacy.displacy.render(
            parsed, style="dep", jupyter=False, page=True
        )
        result["entities"] = spacy.displacy.render(
            parsed, style="ent", jupyter=False, page=True
        )

    return result
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler for use inside a Flyte task.

    Each handled LangChain event is rendered as HTML (markdown headings and
    pandas-backed tables) and appended to a ``flytekit.Deck`` named
    "LangChain Metrics". Counters such as ``step``, ``starts`` and ``ends``
    are not set up here; they are assumed to be initialised by
    ``BaseMetadataCallbackHandler.__init__`` (TODO confirm).
    """

    def __init__(self) -> None:
        """Initialize the callback handler.

        flytekit, the deck renderers and pandas are required. textstat, spacy
        and spacy's ``en_core_web_sm`` model are optional: if one cannot be
        loaded, a warning is logged and ``self.textstat`` / ``self.nlp`` stays
        ``None``, so the corresponding analysis is skipped in ``on_llm_end``.
        """
        flytekit, renderer = import_flytekit()
        self.pandas = import_pandas()

        self.textstat = None
        try:
            self.textstat = import_textstat()
        except ImportError:
            # Implicit concatenation (not backslash continuation inside the
            # literal) keeps source indentation out of the logged message.
            logger.warning(
                "Textstat library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Textstat."
            )

        spacy = None
        try:
            spacy = import_spacy()
        except ImportError:
            logger.warning(
                "Spacy library is not installed. "
                "It may result in the inability to log "
                "certain metrics that can be captured with Spacy."
            )

        super().__init__()

        self.nlp = None
        if spacy is not None:
            try:
                self.nlp = spacy.load("en_core_web_sm")
            except OSError:
                logger.warning(
                    "FlyteCallbackHandler uses spacy's en_core_web_sm model"
                    " for certain metrics. To download,"
                    " run the following command in your terminal:"
                    " `python -m spacy download en_core_web_sm`"
                )

        self.table_renderer = renderer.TableRenderer
        self.markdown_renderer = renderer.MarkdownRenderer
        self.deck = flytekit.Deck(
            "LangChain Metrics",
            self.markdown_renderer().to_html("## LangChain Metrics"),
        )

    def _append_table(self, title: str, resp: Dict[str, Any]) -> None:
        """Append a markdown ``title`` and ``resp`` rendered as a one-row table."""
        self.deck.append(self.markdown_renderer().to_html(title))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when an LLM starts: log its serialized config and the prompts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_llm_start"}
        # ``or {}`` tolerates a missing serialized payload instead of failing
        # inside flatten_dict (same treatment as ``llm_output`` in on_llm_end).
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.get_custom_callback_meta())
        resp["prompts"] = list(prompts)
        self._append_table("### LLM Start", resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when the LLM generates a new token (intentionally a no-op)."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when an LLM finishes: log its output and every generation.

        If textstat and/or spacy are available, each generation is analyzed
        and whichever sections ``analyze_text`` produced are rendered.
        Otherwise the raw generated text is rendered.
        """
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_llm_end"}
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())
        self._append_table("### LLM End", resp)

        can_analyze = self.nlp is not None or self.textstat is not None
        for generations in response.generations:
            for generation in generations:
                if not can_analyze:
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Generated Response")
                    )
                    self.deck.append(
                        self.markdown_renderer().to_html(generation.text)
                    )
                    continue

                analysis = analyze_text(
                    generation.text, nlp=self.nlp, textstat=self.textstat
                )
                # textstat and spacy are independently optional (see
                # __init__), so render each section only if it was produced.
                complexity_metrics = analysis.get("text_complexity_metrics")
                if complexity_metrics is not None:
                    self._append_table(
                        "#### Text Complexity Metrics", complexity_metrics
                    )
                if "dependency_tree" in analysis:
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Dependency Tree")
                    )
                    self.deck.append(analysis["dependency_tree"])
                if "entities" in analysis:
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Entities")
                    )
                    self.deck.append(analysis["entities"])

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when the LLM errors: only the step/error counters are updated."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when a chain starts: log its serialized config and inputs."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_chain_start"}
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.get_custom_callback_meta())
        resp["inputs"] = ",".join(f"{k}={v}" for k, v in inputs.items())
        self._append_table("### Chain Start", resp)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when a chain ends: log its outputs as ``key=value`` pairs."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {
            "action": "on_chain_end",
            "outputs": ",".join(f"{k}={v}" for k, v in outputs.items()),
        }
        resp.update(self.get_custom_callback_meta())
        self._append_table("### Chain End", resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when a chain errors: only the step/error counters are updated."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when arbitrary text is emitted: log the text."""
        self.step += 1
        self.text_ctr += 1

        resp: Dict[str, Any] = {"action": "on_text", "text": text}
        resp.update(self.get_custom_callback_meta())
        self._append_table("### On Text", resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when an agent finishes: log its output and log text."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_finish",
            # .get: an agent's return values are not guaranteed to contain
            # an "output" key; a missing key should not abort logging.
            "output": finish.return_values.get("output"),
            "log": finish.log,
        }
        resp.update(self.get_custom_callback_meta())
        self._append_table("### Agent Finish", resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on an agent action: log the tool, its input and the log text."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
        }
        resp.update(self.get_custom_callback_meta())
        self._append_table("### Agent Action", resp)