Skip to content

Gradient

GradientEmbedding #

Bases: BaseEmbedding

GradientAI嵌入模型。

该类提供了使用在Gradient AI中部署的模型生成嵌入的接口。在初始化时,需要提供集群中部署的模型的model_id。

注意

需要在PYTHONPATH中可用gradientai包。可以使用pip install gradientai进行安装。

Source code in llama_index/embeddings/gradient/base.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class GradientEmbedding(BaseEmbedding):
    """GradientAI嵌入模型。

    该类提供了使用在Gradient AI中部署的模型生成嵌入的接口。在初始化时,需要提供集群中部署的模型的model_id。

    注意:
        需要在PYTHONPATH中可用`gradientai`包。可以使用`pip install gradientai`进行安装。"""

    embed_batch_size: int = Field(default=GRADIENT_EMBED_BATCH_SIZE, gt=0)

    _gradient: Any = PrivateAttr()
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        return "GradientEmbedding"

    def __init__(
        self,
        *,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        gradient_model_slug: str,
        gradient_access_token: Optional[str] = None,
        gradient_workspace_id: Optional[str] = None,
        gradient_host: Optional[str] = None,
        **kwargs: Any,
    ):
        """初始化GradientEmbedding类。

在初始化过程中,导入了`gradientai`包。使用访问令牌、工作区ID和模型的slug,从Gradient AI获取模型并准备好使用。

Args:
    embed_batch_size (int, optional): 用于生成嵌入的批处理大小。默认为10,必须大于0且小于等于100。
    gradient_model_slug (str): Gradient AI帐户中模型的模型slug。
    gradient_access_token (str, optional): Gradient AI帐户的访问令牌,如果为`None`,则从环境变量`GRADIENT_ACCESS_TOKEN`中读取。
    gradient_workspace_id (str, optional): Gradient AI帐户的工作区ID,如果为`None`,则从环境变量`GRADIENT_WORKSPACE_ID`中读取。
    gradient_host (str, optional): Gradient AI API的主机。默认为None,表示使用默认主机。

引发:
    ImportError: 如果PYTHONPATH中找不到`gradientai`包。
    ValueError: 如果无法从Gradient AI获取模型。
"""
        if embed_batch_size <= 0:
            raise ValueError(f"Embed batch size {embed_batch_size}  must be > 0.")

        self._gradient = gradientai.Gradient(
            access_token=gradient_access_token,
            workspace_id=gradient_workspace_id,
            host=gradient_host,
        )

        try:
            self._model = self._gradient.get_embeddings_model(slug=gradient_model_slug)
        except gradientai.openapi.client.exceptions.UnauthorizedException as e:
            logger.error(f"Error while loading model {gradient_model_slug}.")
            self._gradient.close()
            raise ValueError("Unable to fetch the requested embeddings model") from e

        super().__init__(
            embed_batch_size=embed_batch_size, model_name=gradient_model_slug, **kwargs
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        在异步方式下嵌入输入的文本序列。
        """
        inputs = [{"input": text} for text in texts]

        result = await self._model.aembed(inputs=inputs).embeddings

        return [e.embedding for e in result]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        嵌入输入的文本序列。
        """
        inputs = [{"input": text} for text in texts]

        result = self._model.embed(inputs=inputs).embeddings

        return [e.embedding for e in result]

    def _get_text_embedding(self, text: str) -> Embedding:
        """使用单个文本输入的_get_text_embeddings()的别名。"""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """使用单个文本输入的_aget_text_embeddings()的别名。"""
        embedding = await self._aget_text_embeddings([text])
        return embedding[0]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        embedding = await self._aget_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )
        return embedding[0]

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._get_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )[0]