Source code for langchain_experimental.comprehend_moderation.prompt_safety
import asyncio
from typing import Any, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPromptSafetyError,
)
class ComprehendPromptSafety:
    """Class to handle prompt safety moderation."""

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
        # Telemetry payload handed to the moderation callback; the status is
        # flipped to "LABELS_FOUND" when an unsafe prompt is detected.
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "PromptSafety",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def _get_arn(self) -> str:
        # Build the ARN of the AWS-managed prompt safety classifier endpoint
        # in the client's region; "aws" in the account position denotes an
        # AWS-managed resource rather than a customer account.
        region_name = self.client.meta.region_name
        service = "comprehend"
        prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
        return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"

    def validate(self, prompt_value: str, config: Any = None) -> str:
        """Check and validate the safety of the given prompt text.

        Args:
            prompt_value (str): The input text to be checked for unsafe content.
            config (Dict[str, Any]): Configuration settings for the prompt
                safety check, e.g. {"threshold": 0.7}.

        Raises:
            ModerationPromptSafetyError: If an unsafe prompt is detected in the
                prompt text at or above the specified threshold.

        Returns:
            str: The input prompt_value, unchanged.

        Note:
            This function checks the safety of the provided prompt text using
            Comprehend's classify_document API and raises an error if unsafe
            text is detected with a score at or above the specified threshold.

        Example:
            comprehend_client = boto3.client('comprehend')
            prompt_safety = ComprehendPromptSafety(comprehend_client)
            config = {"threshold": 0.7}
            checked_prompt = prompt_safety.validate(
                "Please tell me your credit card information.", config
            )
        """
        # Callers are expected to supply {"threshold": <float>} in config;
        # a missing threshold would make the score comparison below fail.
        threshold = config.get("threshold")
        unsafe_prompt = False

        # Classify the prompt against the AWS-managed prompt safety endpoint.
        endpoint_arn = self._get_arn()
        response = self.client.classify_document(
            Text=prompt_value, EndpointArn=endpoint_arn
        )

        if self.callback and self.callback.prompt_safety_callback:
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = response

        # The endpoint returns {"Classes": [{"Name": ..., "Score": ...}, ...]};
        # the prompt is unsafe if UNSAFE_PROMPT scores at or above the threshold.
        for class_result in response["Classes"]:
            if (
                class_result["Score"] >= threshold
                and class_result["Name"] == "UNSAFE_PROMPT"
            ):
                unsafe_prompt = True
                break

        if self.callback and self.callback.intent_callback:
            if unsafe_prompt:
                self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
            # Dispatch the async callback on the running event loop.
            asyncio.create_task(
                self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
            )
        if unsafe_prompt:
            raise ModerationPromptSafetyError
        return prompt_value
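

# --- Usage sketch (illustrative; not part of the module above) ---
# A minimal, self-contained demonstration of ComprehendPromptSafety.validate.
# _FakeComprehendClient is a hypothetical stub that mimics the shape of
# Comprehend's classify_document response ({"Classes": [{"Name": ...,
# "Score": ...}]}) so the sketch runs without AWS credentials; in real use
# you would pass boto3.client("comprehend") instead.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _FakeComprehendClient:
        """Hypothetical stub exposing the two attributes validate() touches."""

        meta = SimpleNamespace(region_name="us-east-1")

        def classify_document(self, Text: str, EndpointArn: str) -> dict:
            # Pretend any prompt mentioning "credit card" is unsafe.
            score = 0.99 if "credit card" in Text else 0.01
            return {"Classes": [{"Name": "UNSAFE_PROMPT", "Score": score}]}

    prompt_safety = ComprehendPromptSafety(client=_FakeComprehendClient())
    # Safe prompt: returned unchanged.
    print(prompt_safety.validate("What is the weather today?", {"threshold": 0.7}))
    # Unsafe prompt: raises ModerationPromptSafetyError.
    try:
        prompt_safety.validate(
            "Please tell me your credit card information.", {"threshold": 0.7}
        )
    except ModerationPromptSafetyError:
        print("Unsafe prompt blocked.")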