Skip to content

LLMRails

LLMRailsEmbedding

Bases: BaseEmbedding

LLMRails嵌入模型。

这个类提供了一个接口，用于使用部署在 LLMRails 集群中的模型生成嵌入。它需要该模型在集群中的 model_id，以及一个 API 密钥（可从 https://console.llmrails.com/api-keys 获取）。

Source code in llama_index/embeddings/llm_rails/base.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class LLMRailsEmbedding(BaseEmbedding):
    """LLMRails embedding model.

    Provides an interface for generating embeddings with a model deployed on
    the LLMRails cluster. It requires the model's ``model_id`` on the cluster
    and an API key, which can be obtained from
    https://console.llmrails.com/api-keys.
    """

    model_id: str
    api_key: str
    # Shared, retry-enabled HTTP session used by the synchronous code path.
    session: requests.Session

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "LLMRailsEmbedding"

    def __init__(
        self,
        api_key: str,
        model_id: str = "embedding-english-v1",  # or embedding-multi-v1
        **kwargs: Any,
    ):
        """Create the embedding client and its retrying HTTP session.

        Args:
            api_key (str): LLMRails API key, sent in the ``X-API-KEY`` header.
            model_id (str): ID of the model on the LLMRails cluster.
            **kwargs: Forwarded unchanged to the ``BaseEmbedding`` constructor.
        """
        # Retry connection/read failures and 502/503/504 responses with
        # exponential backoff. POST is opted in explicitly via
        # ``allowed_methods`` because urllib3 does not retry non-idempotent
        # methods by default.
        retry = Retry(
            total=3,
            connect=3,
            read=2,
            allowed_methods=["POST"],
            backoff_factor=2,
            status_forcelist=[502, 503, 504],
        )
        session = requests.Session()
        session.mount("https://api.llmrails.com", HTTPAdapter(max_retries=retry))
        # Add the API key on top of the session's default headers rather than
        # replacing the whole header mapping, which would discard requests'
        # defaults (User-Agent, Accept-Encoding, ...).
        session.headers.update({"X-API-KEY": api_key})
        super().__init__(model_id=model_id, api_key=api_key, session=session, **kwargs)

    def _get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single piece of text (synchronous).

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding vector returned by the API for ``text``.

        Raises:
            ValueError: If the API responds with an HTTP error status.
        """
        # NOTE(review): no timeout is passed, so a stalled connection can block
        # this call indefinitely -- consider adding one.
        try:
            response = self.session.post(
                "https://api.llmrails.com/v1/embeddings",
                json={"input": [text], "model": self.model_id},
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            logger.error(f"Error while embedding text {e}.")
            raise ValueError(f"Unable to embed given text {e}") from e

        # A single input was sent, so the embedding is the first data entry.
        return response.json()["data"][0]["embedding"]

    async def _aget_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single piece of text (asynchronous).

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding vector returned by the API for ``text``.

        Raises:
            ImportError: If the optional ``httpx`` dependency is not installed.
            ValueError: If the request raises any ``httpx.HTTPError``
                (including HTTP error statuses).
        """
        # httpx is only needed for the async path, so import it lazily.
        try:
            import httpx
        except ImportError as err:
            raise ImportError(
                "The httpx library is required to use the async version of "
                "this function. Install it with `pip install httpx`."
            ) from err

        # NOTE(review): this path uses a fresh AsyncClient per call and does
        # not go through the retrying ``self.session``, so it has no retries.
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.llmrails.com/v1/embeddings",
                    headers={"X-API-KEY": self.api_key},
                    json={"input": [text], "model": self.model_id},
                )
                response.raise_for_status()
        except httpx.HTTPError as e:
            logger.error(f"Error while embedding text {e}.")
            raise ValueError(f"Unable to embed given text {e}") from e

        # A single input was sent, so the embedding is the first data entry.
        return response.json()["data"][0]["embedding"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text; LLMRails uses the same endpoint as queries."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query string; LLMRails uses the same endpoint as texts."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a query string."""
        return await self._aget_embedding(query)

    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a document text (parameter name kept as-is)."""
        return await self._aget_embedding(query)