Source code for langchain.chains.moderation
"""通过一个中继端传递输入。"""
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import check_package_version, get_from_dict_or_env
from langchain.chains.base import Chain
class OpenAIModerationChain(Chain):
    """Pass input through the OpenAI moderation endpoint.

    To use, you should have the ``openai`` python package installed, and
    either pass ``openai_api_key`` or set the ``OPENAI_API_KEY`` environment
    variable to your API key.

    Example:
        .. code-block:: python

            from langchain.chains import OpenAIModerationChain

            moderation = OpenAIModerationChain()
    """

    client: Any  #: :meta private:
    async_client: Any  #: :meta private:
    model_name: Optional[str] = None
    """Moderation model name to use."""
    error: bool = False
    """Whether or not to raise an error if bad content was found."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    openai_api_key: Optional[str] = None
    openai_organization: Optional[str] = None
    # Populated by ``validate_environment``: True when the installed ``openai``
    # package is older than 1.0 (legacy module-level API).
    _openai_pre_1_0: bool = Field(default=None)

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the ``openai`` package are available.

        Resolves the API key (required) and organization (optional) from the
        supplied values or the environment, then stores the moderation
        client(s) matching the installed ``openai`` version in ``values``.

        Args:
            values: Field values collected for the model.

        Returns:
            ``values`` with ``client``, ``_openai_pre_1_0`` and, for
            ``openai>=1.0``, ``async_client`` filled in.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        openai_organization = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )

        # Keep the try body minimal so only a genuinely missing package is
        # reported as "could not import".
        try:
            import openai
        except ImportError as exc:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc

        # Module-level settings are what the legacy (<1.0) API reads.
        openai.api_key = openai_api_key
        if openai_organization:
            openai.organization = openai_organization

        # check_package_version raises ValueError when the installed version
        # does not satisfy the constraint, i.e. openai < 1.0.
        try:
            check_package_version("openai", gte_version="1.0")
        except ValueError:
            pre_1_0 = True
        else:
            pre_1_0 = False
        values["_openai_pre_1_0"] = pre_1_0

        if pre_1_0:
            values["client"] = openai.Moderation
        else:
            # Explicit client instances do not read the module-level
            # settings above, so hand them the resolved credentials directly.
            # An empty organization becomes None so the client applies its
            # own default.
            client_kwargs = {
                "api_key": openai_api_key,
                "organization": openai_organization or None,
            }
            values["client"] = openai.OpenAI(**client_kwargs)
            values["async_client"] = openai.AsyncOpenAI(**client_kwargs)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expected input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Returned output keys.

        :meta private:
        """
        return [self.output_key]

    def _moderate(self, text: str, results: Any) -> str:
        """Turn a single moderation result into the chain's output text.

        Args:
            text: The text that was submitted for moderation.
            results: One moderation result; indexed like a mapping for
                ``openai<1.0`` and read via attributes otherwise.

        Returns:
            ``text`` unchanged when it was not flagged, otherwise a fixed
            policy-violation message.

        Raises:
            ValueError: If the text was flagged and ``self.error`` is True.
        """
        flagged = results["flagged"] if self._openai_pre_1_0 else results.flagged
        if not flagged:
            return text

        error_str = "Text was found that violates OpenAI's content policy."
        if self.error:
            raise ValueError(error_str)
        return error_str

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Moderate ``inputs[self.input_key]`` synchronously.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Callback manager for this run (unused here).

        Returns:
            A dict mapping ``self.output_key`` to the moderated text.
        """
        text = inputs[self.input_key]
        # Only send ``model`` when one was configured so the endpoint's
        # default model is used otherwise.
        model_kwargs = {} if self.model_name is None else {"model": self.model_name}
        if self._openai_pre_1_0:
            results = self.client.create(text, **model_kwargs)
            first_result = results["results"][0]
        else:
            results = self.client.moderations.create(input=text, **model_kwargs)
            first_result = results.results[0]
        return {self.output_key: self._moderate(text, first_result)}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Moderate ``inputs[self.input_key]`` asynchronously.

        With ``openai<1.0`` no async client is configured, so the call is
        delegated to the base class implementation.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Async callback manager for this run.

        Returns:
            A dict mapping ``self.output_key`` to the moderated text.
        """
        if self._openai_pre_1_0:
            return await super()._acall(inputs, run_manager=run_manager)

        text = inputs[self.input_key]
        # Only send ``model`` when one was configured so the endpoint's
        # default model is used otherwise.
        model_kwargs = {} if self.model_name is None else {"model": self.model_name}
        results = await self.async_client.moderations.create(
            input=text, **model_kwargs
        )
        return {self.output_key: self._moderate(text, results.results[0])}