# Source code for langchain_experimental.prompt_injection_identifier.hugging_face_identifier
"""Tool for identifying prompt injection attacks."""
from __future__ import annotations

from typing import TYPE_CHECKING, Union

from langchain.pydantic_v1 import Field, root_validator
from langchain.tools.base import BaseTool

# Pipeline is only needed for type annotations; importing transformers lazily
# keeps it an optional dependency at runtime.
if TYPE_CHECKING:
    from transformers import Pipeline
class PromptInjectionException(ValueError):
    """Exception raised when a prompt injection attack is detected.

    Attributes:
        message: Human-readable description of the detection.
        score: Model confidence that the input is an injection (0.0-1.0).
    """

    def __init__(
        self, message: str = "Prompt injection attack detected", score: float = 1.0
    ):
        # Store both values on the instance so callers can inspect how
        # confident the model was when handling the exception.
        self.message = message
        self.score = score
        super().__init__(self.message)
def _model_default_factory(
model_name: str = "protectai/deberta-v3-base-prompt-injection-v2",
) -> Pipeline:
try:
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
pipeline,
)
except ImportError as e:
raise ImportError(
"Cannot import transformers, please install with "
"`pip install transformers`."
) from e
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
return pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
max_length=512, # default length of BERT models
truncation=True, # otherwise it will fail on long prompts
)
class HuggingFaceInjectionIdentifier(BaseTool):
    """Tool that uses a HuggingFace Prompt Injection model to detect
    prompt injection attacks."""

    name: str = "hugging_face_injection_identifier"
    description: str = (
        "A wrapper around HuggingFace Prompt Injection security model. "
        "Useful for when you need to ensure that prompt is free of injection attacks. "
        "Input should be any message from the user."
    )
    model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
    """Model to use for prompt injection detection.

    Can be specified as a transformers Pipeline or as a string. A string
    should correspond to the model name of a text-classification
    transformers model. Defaults to the
    ``protectai/deberta-v3-base-prompt-injection-v2`` model.
    """
    threshold: float = Field(
        description="Threshold for prompt injection detection.", default=0.5
    )
    """Threshold for prompt injection detection.

    Defaults to 0.5."""
    injection_label: str = Field(
        description="Label of the injection for prompt injection detection.",
        default="INJECTION",
    )
    """Label used by the prompt injection detection model.

    Defaults to ``INJECTION``. The value depends on the model used."""

    @root_validator(pre=True)
    def validate_environment(cls, values: dict) -> dict:
        """If ``model`` was given as a string, resolve it into a pipeline."""
        if isinstance(values.get("model"), str):
            values["model"] = _model_default_factory(model_name=values["model"])
        return values

    def _run(self, query: str) -> str:
        """Check ``query`` for prompt injection.

        Args:
            query: Any message from the user.

        Returns:
            The query unchanged when no injection is detected.

        Raises:
            PromptInjectionException: If the normalized injection score
                exceeds ``self.threshold``.
        """
        result = self.model(query)  # type: ignore
        # The pipeline returns a single {label, score} dict; normalize the
        # score so it always measures the likelihood of injection,
        # regardless of which label the model reported.
        score = (
            result[0]["score"]
            if result[0]["label"] == self.injection_label
            else 1 - result[0]["score"]
        )
        if score > self.threshold:
            raise PromptInjectionException("Prompt injection attack detected", score)
        return query