import os
import warnings
from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from packaging.version import parse
class ArgillaCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs into Argilla.

    Args:
        dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
            exist in advance. If you need help on how to create a `FeedbackDataset`
            in Argilla, please visit
            https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
        workspace_name: name of the workspace in Argilla where the specified
            `FeedbackDataset` lives in. Defaults to `None`, which means that the
            default workspace will be used.
        api_url: URL of the Argilla Server that we want to use, and where the
            `FeedbackDataset` lives in. Defaults to `None`, which means that either
            the `ARGILLA_API_URL` environment variable or the default will be used.
        api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
            means that either the `ARGILLA_API_KEY` environment variable or the
            default will be used.

    Raises:
        ImportError: if the `argilla` package is not installed.
        ConnectionError: if the connection to Argilla fails.
        FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.

    Examples:
        >>> from langchain_community.llms import OpenAI
        >>> from langchain_community.callbacks import ArgillaCallbackHandler
        >>> argilla_callback = ArgillaCallbackHandler(
        ...     dataset_name="my-dataset",
        ...     workspace_name="my-workspace",
        ...     api_url="http://localhost:6900",
        ...     api_key="argilla.apikey",
        ... )
        >>> llm = OpenAI(
        ...     temperature=0,
        ...     callbacks=[argilla_callback],
        ...     verbose=True,
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> llm.generate([
        ...     "What is the best NLP-annotation tool out there? (no bias at all)",
        ... ])
        "Argilla, no doubt about it."
    """

    # Links used when building warning/error messages throughout the handler.
    REPO_URL: str = "https://github.com/argilla-io/argilla"
    ISSUES_URL: str = f"{REPO_URL}/issues"
    BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html"
    # Fallback API URL when neither the `api_url` argument nor the
    # `ARGILLA_API_URL` environment variable is set (Argilla Quickstart default).
    DEFAULT_API_URL: str = "http://localhost:6900"
[docs] def __init__(
self,
dataset_name: str,
workspace_name: Optional[str] = None,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""初始化`ArgillaCallbackHandler`。
参数:
dataset_name:Argilla中`FeedbackDataset`的名称。请注意,它必须事先存在。如果您需要关于如何在Argilla中创建`FeedbackDataset`的帮助,请访问https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html。
workspace_name:Argilla中指定的`FeedbackDataset`所在的工作区的名称。默认为`None`,这意味着将使用默认工作区。
api_url:我们要使用的Argilla服务器的URL,以及`FeedbackDataset`所在的位置。默认为`None`,这意味着将使用`ARGILLA_API_URL`环境变量或默认值。
api_key:连接到Argilla服务器的API密钥。默认为`None`,这意味着将使用`ARGILLA_API_KEY`环境变量或默认值。
抛出:
ImportError:如果未安装`argilla`包。
ConnectionError:如果连接到Argilla失败。
FileNotFoundError:如果从Argilla检索`FeedbackDataset`失败。
"""
super().__init__()
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg
self.ARGILLA_VERSION = rg.__version__
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
" default API URL in Argilla Quickstart."
),
)
api_url = self.DEFAULT_API_URL
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
self.DEFAULT_API_KEY = (
"admin.apikey"
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
else "owner.apikey"
)
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
" default API key in Argilla Quickstart."
),
)
api_key = self.DEFAULT_API_KEY
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(api_key=api_key, api_url=api_url)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
f"please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
# Set the Argilla variables
self.dataset_name = dataset_name
self.workspace_name = workspace_name or rg.get_workspace()
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
**extra_args,
)
except Exception as e:
raise FileNotFoundError(
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
f" please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f"{self.workspace_name} had fields that are not supported yet for the"
f"`langchain` integration. Supported fields are: {supported_fields},"
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
" For more information on how to create a `langchain`-compatible"
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
[docs] def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""当一个LLM启动时,将提示保存在内存中。"""
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
[docs] def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""当生成一个新的令牌时不执行任何操作。"""
pass
[docs] def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""当LLM结束时,将日志记录到Argilla。"""
# 什么都不做 if there's a parent_run_id, since we will log the records when
# the chain ends
if kwargs["parent_run_id"]:
return
# Creates the records and adds them to the `FeedbackDataset`
prompts = self.prompts[str(kwargs["run_id"])]
for prompt, generations in zip(prompts, response.generations):
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": generation.text.strip(),
},
}
for generation in generations
]
)
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
[docs] def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""LLM 输出错误时不执行任何操作。"""
pass
[docs] def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""如果键`input`在`inputs`中,则使用`parent_run_id`或`run_id`将其保存在`self.prompts`中。这样做是为了避免在LLM启动时和链启动时重复记录相同的输入提示。
"""
if "input" in inputs:
self.prompts.update(
{
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
inputs["input"]
if isinstance(inputs["input"], list)
else [inputs["input"]]
)
}
)
[docs] def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""如果`parent_run_id`或`run_id`中的任一个在`self.prompts`中,那么将输出记录到Argilla,并从`self.prompts`中弹出该运行。如果输出是一个列表或不是一个列表,则行为会有所不同。
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
List, self.prompts.get(str(kwargs["run_id"]), [])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, chain_output_val)
]
)
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": chain_output_val.strip(),
},
}
]
)
# Pop current run from `self.runs`
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
[docs] def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""当LLM链输出错误时不执行任何操作。"""
pass
[docs] def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""当代理执行特定动作时不执行任何操作。"""
pass
[docs] def on_text(self, text: str, **kwargs: Any) -> None:
"""什么都不做"""
pass
[docs] def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""什么都不做"""
pass