Source code for langchain_community.callbacks.argilla_callback

import os
import warnings
from typing import Any, Dict, List, Optional, cast

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from packaging.version import parse


[docs]class ArgillaCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs into Argilla.

    Args:
        dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
            exist in advance. If you need help on how to create a `FeedbackDataset`
            in Argilla, please visit
            https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
        workspace_name: name of the workspace in Argilla where the specified
            `FeedbackDataset` lives in. Defaults to `None`, which means that the
            default workspace will be used.
        api_url: URL of the Argilla Server that we want to use, and where the
            `FeedbackDataset` lives in. Defaults to `None`, which means that either
            the `ARGILLA_API_URL` environment variable or the default will be used.
        api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
            means that either the `ARGILLA_API_KEY` environment variable or the
            default will be used.

    Raises:
        ImportError: if the `argilla` package is not installed.
        ConnectionError: if the connection to Argilla fails.
        FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.

    Examples:
        >>> from langchain_community.llms import OpenAI
        >>> from langchain_community.callbacks import ArgillaCallbackHandler
        >>> argilla_callback = ArgillaCallbackHandler(
        ...     dataset_name="my-dataset",
        ...     workspace_name="my-workspace",
        ...     api_url="http://localhost:6900",
        ...     api_key="argilla.apikey",
        ... )
        >>> llm = OpenAI(
        ...     temperature=0,
        ...     callbacks=[argilla_callback],
        ...     verbose=True,
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> llm.generate([
        ...     "What is the best NLP-annotation tool out there? (no bias at all)",
        ... ])
        "Argilla, no doubt about it."
    """

    REPO_URL: str = "https://github.com/argilla-io/argilla"
    ISSUES_URL: str = f"{REPO_URL}/issues"
    BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html"

    DEFAULT_API_URL: str = "http://localhost:6900"
[docs]    def __init__(
        self,
        dataset_name: str,
        workspace_name: Optional[str] = None,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """Initializes the `ArgillaCallbackHandler`.

        Args:
            dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
                exist in advance. If you need help on how to create a `FeedbackDataset`
                in Argilla, please visit
                https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
            workspace_name: name of the workspace in Argilla where the specified
                `FeedbackDataset` lives in. Defaults to `None`, which means that the
                default workspace will be used.
            api_url: URL of the Argilla Server that we want to use, and where the
                `FeedbackDataset` lives in. Defaults to `None`, which means that either
                the `ARGILLA_API_URL` environment variable or the default will be used.
            api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
                means that either the `ARGILLA_API_KEY` environment variable or the
                default will be used.

        Raises:
            ImportError: if the `argilla` package is not installed.
            ConnectionError: if the connection to Argilla fails.
            FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
        """
        super().__init__()

        # Import Argilla (not via `import_argilla` to keep hints in IDEs)
        try:
            import argilla as rg

            self.ARGILLA_VERSION = rg.__version__
        except ImportError:
            raise ImportError(
                "To use the Argilla callback manager you need to have the `argilla` "
                "Python package installed. Please install it with `pip install argilla`"
            )

        # Check whether the Argilla version is compatible
        if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
            raise ImportError(
                f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
                "`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
                "upgrade `argilla` with `pip install --upgrade argilla`."
            )

        # Show a warning message if Argilla will assume the default values will be used
        if api_url is None and os.getenv("ARGILLA_API_URL") is None:
            warnings.warn(
                (
                    "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
                    f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
                    " default API URL in Argilla Quickstart."
                ),
            )
            api_url = self.DEFAULT_API_URL

        if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
            self.DEFAULT_API_KEY = (
                "admin.apikey"
                if parse(self.ARGILLA_VERSION) < parse("1.11.0")
                else "owner.apikey"
            )
            warnings.warn(
                (
                    "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
                    f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
                    " default API key in Argilla Quickstart."
                ),
            )
            api_key = self.DEFAULT_API_KEY

        # Connect to Argilla with the provided credentials, if applicable
        try:
            rg.init(api_key=api_key, api_url=api_url)
        except Exception as e:
            raise ConnectionError(
                f"Could not connect to Argilla with exception: '{e}'.\n"
                "Please check your `api_key` and `api_url`, and make sure that "
                "the Argilla server is up and running. If the problem persists "
                f"please report it to {self.ISSUES_URL} as an `integration` issue."
            ) from e

        # Set the Argilla variables
        self.dataset_name = dataset_name
        self.workspace_name = workspace_name or rg.get_workspace()

        # Retrieve the `FeedbackDataset` from Argilla (without existing records)
        try:
            extra_args = {}
            if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
                warnings.warn(
                    f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
                    " higher is recommended.",
                    UserWarning,
                )
                extra_args = {"with_records": False}
            self.dataset = rg.FeedbackDataset.from_argilla(
                name=self.dataset_name,
                workspace=self.workspace_name,
                **extra_args,
            )
        except Exception as e:
            raise FileNotFoundError(
                f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
                f"\nPlease check that the dataset with name={self.dataset_name} in the"
                f" workspace={self.workspace_name} exists in advance. If you need help"
                " on how to create a `langchain`-compatible `FeedbackDataset` in"
                f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
                f" please report it to {self.ISSUES_URL} as an `integration` issue."
            ) from e

        supported_fields = ["prompt", "response"]
        if supported_fields != [field.name for field in self.dataset.fields]:
            raise ValueError(
                f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
                f"{self.workspace_name} had fields that are not supported yet for the"
                f"`langchain` integration. Supported fields are: {supported_fields},"
                f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}."  # noqa: E501
                " For more information on how to create a `langchain`-compatible"
                f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
            )

        self.prompts: Dict[str, List[str]] = {}

        warnings.warn(
            (
                "The `ArgillaCallbackHandler` is currently in beta and is subject to"
                " change based on updates to `langchain`. Please report any issues to"
                f" {self.ISSUES_URL} as an `integration` issue."
            ),
        )
[docs]    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Save the prompts in memory when an LLM starts."""
        self.prompts.update(
            {str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts}
        )
[docs]    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass
[docs]    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log records to Argilla when an LLM ends."""
        # Do nothing if there's a parent_run_id, since we will log the records when
        # the chain ends
        if kwargs["parent_run_id"]:
            return

        # Creates the records and adds them to the `FeedbackDataset`
        prompts = self.prompts[str(kwargs["run_id"])]
        for prompt, generations in zip(prompts, response.generations):
            self.dataset.add_records(
                records=[
                    {
                        "fields": {
                            "prompt": prompt,
                            "response": generation.text.strip(),
                        },
                    }
                    for generation in generations
                ]
            )

        # Pop current run from `self.runs`
        self.prompts.pop(str(kwargs["run_id"]))

        if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
            # Push the records to Argilla
            self.dataset.push_to_argilla()
[docs]    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when the LLM outputs an error."""
        pass
[docs]    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """If the key `input` is in `inputs`, save it in `self.prompts` using either
        the `parent_run_id` or the `run_id` as the key. This is done so that we don't
        log the same input prompt twice, once when the LLM starts and once when the
        chain starts.
        """
        if "input" in inputs:
            self.prompts.update(
                {
                    str(kwargs["parent_run_id"] or kwargs["run_id"]): (
                        inputs["input"]
                        if isinstance(inputs["input"], list)
                        else [inputs["input"]]
                    )
                }
            )
[docs]    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """If either the `parent_run_id` or the `run_id` is in `self.prompts`, log the
        outputs to Argilla and pop the run from `self.prompts`. The behavior differs
        depending on whether the output is a list or not.
        """
        if not any(
            key in self.prompts
            for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
        ):
            return
        prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
            List, self.prompts.get(str(kwargs["run_id"]), [])
        )
        for chain_output_key, chain_output_val in outputs.items():
            if isinstance(chain_output_val, list):
                # Creates the records and adds them to the `FeedbackDataset`
                self.dataset.add_records(
                    records=[
                        {
                            "fields": {
                                "prompt": prompt,
                                "response": output["text"].strip(),
                            },
                        }
                        for prompt, output in zip(prompts, chain_output_val)
                    ]
                )
            else:
                # Creates the records and adds them to the `FeedbackDataset`
                self.dataset.add_records(
                    records=[
                        {
                            "fields": {
                                "prompt": " ".join(prompts),
                                "response": chain_output_val.strip(),
                            },
                        }
                    ]
                )

        # Pop current run from `self.runs`
        if str(kwargs["parent_run_id"]) in self.prompts:
            self.prompts.pop(str(kwargs["parent_run_id"]))
        if str(kwargs["run_id"]) in self.prompts:
            self.prompts.pop(str(kwargs["run_id"]))

        if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
            # Push the records to Argilla
            self.dataset.push_to_argilla()
[docs]    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when the LLM chain outputs an error."""
        pass
[docs]    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when a tool starts."""
        pass
[docs]    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when the agent takes a specific action."""
        pass
[docs]    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when a tool ends."""
        pass
[docs]    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when a tool outputs an error."""
        pass
[docs]    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass
[docs]    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass
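The records created by the callbacks above require a `FeedbackDataset` whose fields are exactly `prompt` and `response`. The snippet below is a minimal sketch of how such a dataset and a compatible chain could be wired up; it assumes the Argilla 1.x `FeedbackDataset` API and uses placeholder names (`my-dataset`, `my-workspace`, `owner.apikey`, the question `response-rating`). The question set is up to you, and note that `on_chain_start` only captures inputs when the chain's input variable is literally named `input`.

    >>> import argilla as rg
    >>> rg.init(api_url="http://localhost:6900", api_key="owner.apikey")
    >>> # Create a `langchain`-compatible `FeedbackDataset`: two text fields named
    >>> # exactly "prompt" and "response", plus at least one question for annotators.
    >>> dataset = rg.FeedbackDataset(
    ...     fields=[
    ...         rg.TextField(name="prompt"),
    ...         rg.TextField(name="response"),
    ...     ],
    ...     questions=[
    ...         rg.RatingQuestion(
    ...             name="response-rating",
    ...             description="How would you rate the quality of the response?",
    ...             values=[1, 2, 3, 4, 5],
    ...         ),
    ...     ],
    ... )
    >>> dataset.push_to_argilla(name="my-dataset", workspace="my-workspace")
    >>> # Attach the handler to a chain; the prompt template variable is called
    >>> # "input" so that `on_chain_start` stores it in `self.prompts`.
    >>> from langchain.chains import LLMChain
    >>> from langchain.prompts import PromptTemplate
    >>> from langchain_community.callbacks import ArgillaCallbackHandler
    >>> from langchain_community.llms import OpenAI
    >>> argilla_callback = ArgillaCallbackHandler(
    ...     dataset_name="my-dataset",
    ...     workspace_name="my-workspace",
    ...     api_url="http://localhost:6900",
    ...     api_key="owner.apikey",
    ... )
    >>> llm = OpenAI(temperature=0.3, callbacks=[argilla_callback])
    >>> chain = LLMChain(
    ...     llm=llm,
    ...     prompt=PromptTemplate.from_template("Summarize in one line: {input}"),
    ...     callbacks=[argilla_callback],
    ... )
    >>> chain.run("Argilla is an open-source data curation platform for LLMs.")

With this setup, each run adds one record per prompt/response pair to the dataset; on Argilla versions below 1.14.0 the handler also pushes the records explicitly via `push_to_argilla()`.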