Source code for langchain_experimental.comprehend_moderation.toxicity
import asyncio
import importlib
from typing import Any, List, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)
class ComprehendToxicity:
    """Class to handle toxicity moderation."""

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Toxicity",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id
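
    # Illustrative sketch (assumption, not part of the original source): the class
    # is typically constructed with a boto3 Comprehend client, e.g.
    #
    #   import boto3
    #   comprehend_client = boto3.client("comprehend", region_name="us-east-1")
    #   moderator = ComprehendToxicity(client=comprehend_client)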
    def _toxicity_init_validate(self, max_size: int) -> Any:
        """Validate and initialize the toxicity processing configuration.

        Args:
            max_size (int): Maximum sentence size defined in the
                configuration object.

        Raises:
            Exception: If the maximum sentence size exceeds the 5KB limit.

        Note:
            This function ensures that the NLTK punkt tokenizer is downloaded
            if it is not already present.

        Returns:
            The imported ``nltk`` module.
        """
        if max_size > 1024 * 5:
            raise Exception("The sentence length should not exceed 5KB.")
        try:
            nltk = importlib.import_module("nltk")
            nltk.data.find("tokenizers/punkt")
            return nltk
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import nltk python package. "
                "Please install it with `pip install nltk`."
            )
        except LookupError:
            # Download the punkt tokenizer data on first use, then return nltk
            # so callers can tokenize immediately.
            nltk.download("punkt")
            return nltk
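
    # Illustrative sketch (assumption, not part of the original source): internal
    # usage of the validator helper.
    #
    #   nltk = self._toxicity_init_validate(1024 * 4)  # within the 5KB limit, returns nltk
    #   self._toxicity_init_validate(1024 * 6)         # exceeds 5KB, raises Exception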
    def _split_paragraph(
        self, prompt_value: str, max_size: int = 1024 * 4
    ) -> List[List[str]]:
        """Split a paragraph into chunks of sentences, respecting a maximum size limit.

        Args:
            prompt_value (str): The input paragraph to split into chunks.
            max_size (int, optional): Maximum size limit for each chunk, in bytes.
                Defaults to 4096 (4KB).

        Returns:
            List[List[str]]: A list of chunks, where each chunk is a list of sentences.

        Note:
            This function uses the ``_toxicity_init_validate`` function to validate
            the maximum sentence size against the service limits. It uses the NLTK
            sentence tokenizer to split the paragraph into sentences.

        Example:
            paragraph = "This is a sample paragraph. It contains multiple sentences. ..."
            chunks = split_paragraph(paragraph, max_size=2048)
        """
        # Validate the maximum sentence size based on service limits
        nltk = self._toxicity_init_validate(max_size)
        sentences = nltk.sent_tokenize(prompt_value)

        chunks: List[List[str]] = []
        current_chunk: List[str] = []
        current_size = 0

        for sentence in sentences:
            sentence_size = len(sentence.encode("utf-8"))
            # If adding a new sentence exceeds max_size
            # or current_chunk has 10 sentences, start a new chunk
            if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
                if current_chunk:  # Avoid appending empty chunks
                    chunks.append(current_chunk)
                current_chunk = []
                current_size = 0
            current_chunk.append(sentence)
            current_size += sentence_size

        # Add any remaining sentences
        if current_chunk:
            chunks.append(current_chunk)
        return chunks
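
    # Illustrative sketch (assumption, not part of the original source): chunking a
    # short prompt yields a single chunk because it fits well under the 4KB default.
    #
    #   chunks = moderator._split_paragraph("First sentence. Second sentence.")
    #   # -> [["First sentence.", "Second sentence."]]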
    def validate(self, prompt_value: str, config: Any = None) -> str:
        """Check the toxicity of a given text prompt using the AWS Comprehend
        service and apply actions based on the configuration.

        Args:
            prompt_value (str): The text content to be checked for toxicity.
            config (Dict[str, Any]): Configuration for toxicity checks and actions.

        Returns:
            str: The original prompt_value if allowed or no toxicity is found.

        Raises:
            ModerationToxicityError: If toxic labels are found in the prompt and
                it cannot be processed per the configuration.
        """
        chunks = self._split_paragraph(prompt_value=prompt_value)
        for sentence_list in chunks:
            segments = [{"Text": sentence} for sentence in sentence_list]
            response = self.client.detect_toxic_content(
                TextSegments=segments, LanguageCode="en"
            )
            if self.callback and self.callback.toxicity_callback:
                self.moderation_beacon["moderation_input"] = segments  # type: ignore
                self.moderation_beacon["moderation_output"] = response

            toxicity_found = False
            threshold = config.get("threshold")
            toxicity_labels = config.get("labels")

            if not toxicity_labels:
                # No specific labels configured: flag the prompt if any label
                # meets or exceeds the threshold.
                for item in response["ResultList"]:
                    for label in item["Labels"]:
                        if label["Score"] >= threshold:
                            toxicity_found = True
                            break
            else:
                # Only flag the prompt if one of the configured labels meets
                # or exceeds the threshold.
                for item in response["ResultList"]:
                    for label in item["Labels"]:
                        if (
                            label["Name"] in toxicity_labels
                            and label["Score"] >= threshold
                        ):
                            toxicity_found = True
                            break

            if self.callback and self.callback.toxicity_callback:
                if toxicity_found:
                    self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
                asyncio.create_task(
                    self.callback.on_after_toxicity(
                        self.moderation_beacon, self.unique_id
                    )
                )
            if toxicity_found:
                raise ModerationToxicityError
        return prompt_value
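
# Illustrative usage sketch (not part of the original module). It assumes the
# boto3 package is installed and AWS credentials with Amazon Comprehend access
# are configured; the region, threshold, and prompt below are placeholders.
if __name__ == "__main__":
    import boto3

    comprehend_client = boto3.client("comprehend", region_name="us-east-1")
    moderator = ComprehendToxicity(client=comprehend_client)

    # An empty "labels" list means any toxicity label at or above the threshold
    # will flag the prompt.
    toxicity_config = {"threshold": 0.5, "labels": []}

    try:
        text = moderator.validate("A user supplied prompt.", config=toxicity_config)
        print("Prompt allowed:", text)
    except ModerationToxicityError:
        print("Prompt blocked: toxic content detected.")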