Skip to content

Google

GeminiEmbedding #

Bases: BaseEmbedding

谷歌 Gemini 嵌入。

Parameters:

Name Type Description Default
model_name str

嵌入模型的名称。默认为 "models/embedding-001"。

'models/embedding-001'
api_key Optional[str]

访问模型的 API 密钥。默认为 None。

None
Source code in llama_index/embeddings/google/gemini.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class GeminiEmbedding(BaseEmbedding):
    """Google Gemini embeddings.

    Args:
        model_name (str): Name of the embedding model.
            Defaults to "models/embedding-001".
        task_type (Optional[str]): Task type passed to ``embed_content`` for
            every request (queries and documents alike).
            Defaults to "retrieval_document".
        api_key (Optional[str]): API key passed to ``gemini.configure``.
            Defaults to None.
        title (Optional[str]): Title passed to ``embed_content``. Per the
            field description it only applies to ``retrieval_document``
            tasks. Defaults to None.
        embed_batch_size (int): Batch size used by ``BaseEmbedding``.
        callback_manager (Optional[CallbackManager]): Callback manager
            forwarded to ``BaseEmbedding``.
        **kwargs: Extra arguments forwarded to ``BaseEmbedding``.
    """

    _model: Any = PrivateAttr()
    title: Optional[str] = Field(
        default="",
        description="Title is only applicable for retrieval_document tasks, and is used to represent a document title. For other tasks, title is invalid.",
    )
    task_type: Optional[str] = Field(
        default="retrieval_document",
        description="The task for embedding model.",
    )

    def __init__(
        self,
        model_name: str = "models/embedding-001",
        task_type: Optional[str] = "retrieval_document",
        api_key: Optional[str] = None,
        title: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        gemini.configure(api_key=api_key)

        # Declared fields go through pydantic's constructor instead of being
        # patched onto the instance afterwards.
        super().__init__(
            model_name=model_name,
            title=title,
            task_type=task_type,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )
        # Private attributes are assigned only after BaseModel.__init__ has
        # run; pydantic v2 rejects private-attribute writes before that.
        self._model = gemini

    @classmethod
    def class_name(cls) -> str:
        return "GeminiEmbedding"

    def _embed_content(self, content: str) -> List[float]:
        """Embed one string with the configured model, title and task type."""
        return self._model.embed_content(
            model=self.model_name,
            content=content,
            title=self.title,
            task_type=self.task_type,
        )["embedding"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a query string.

        NOTE(review): queries use the same ``task_type`` as documents
        (default "retrieval_document"); confirm that is intended.
        """
        return self._embed_content(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text."""
        return self._embed_content(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for several texts, one request per text."""
        return [self._embed_content(text) for text in texts]

    ### Async methods ###
    # need to wait async calls from Gemini side to be implemented.
    # Issue: https://github.com/google/generative-ai-python/issues/125
    # Until then these delegate to the synchronous calls, which block the
    # event loop while each request runs.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding``."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding``."""
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings``."""
        return self._get_text_embeddings(texts)

GooglePaLMEmbedding #

Bases: BaseEmbedding

用于Google PaLM嵌入的类。

Parameters:

Name Type Description Default
model_name str

用于嵌入的模型。默认为 "models/embedding-gecko-001"。

'models/embedding-gecko-001'
api_key Optional[str]

访问模型的API密钥。默认为None。

None
Source code in llama_index/embeddings/google/palm.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class GooglePaLMEmbedding(BaseEmbedding):
    """Google PaLM embeddings.

    Args:
        model_name (str): Model used for embedding.
            Defaults to "models/embedding-gecko-001".
        api_key (Optional[str]): API key passed to ``palm.configure``.
            Defaults to None.
        embed_batch_size (int): Batch size used by ``BaseEmbedding``.
        callback_manager (Optional[CallbackManager]): Callback manager
            forwarded to ``BaseEmbedding``.
        **kwargs: Extra arguments forwarded to ``BaseEmbedding``.
    """

    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "models/embedding-gecko-001",
        api_key: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        palm.configure(api_key=api_key)

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )
        # Private attributes are assigned only after BaseModel.__init__ has
        # run; pydantic v2 rejects private-attribute writes before that.
        self._model = palm

    @classmethod
    def class_name(cls) -> str:
        # NOTE(review): differs from the Python class name; left unchanged in
        # case it is relied upon as a persisted identifier.
        return "PaLMEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a query string."""
        return self._model.generate_embeddings(model=self.model_name, text=query)[
            "embedding"
        ]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding``.

        Delegates to the synchronous request (which blocks the event loop
        while it runs); the previous code awaited ``aget_embedding`` on the
        client module, which does not provide it.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text."""
        return self._model.generate_embeddings(model=self.model_name, text=text)[
            "embedding"
        ]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding``.

        ``_get_text_embedding`` is a method of this class, not of the client
        module, so it must be called on ``self``.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for several texts in a single request."""
        return self._model.generate_embeddings(model=self.model_name, text=texts)[
            "embedding"
        ]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings`` (delegates to the sync call)."""
        return self._get_text_embeddings(texts)

GoogleUnivSentEncoderEmbedding #

Bases: BaseEmbedding

Source code in llama_index/embeddings/google/univ_sent_encoder.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class GoogleUnivSentEncoderEmbedding(BaseEmbedding):
    """Embeddings from a model loaded with ``tensorflow_hub.load``.

    Args:
        handle (Optional[str]): Handle passed to ``tensorflow_hub.load``;
            falls back to ``DEFAULT_HANDLE`` when not given. Also used as
            ``model_name``.
        embed_batch_size (int): Batch size used by ``BaseEmbedding``.
        callback_manager (Optional[CallbackManager]): Callback manager
            forwarded to ``BaseEmbedding``.
        **kwargs: Extra arguments forwarded to ``BaseEmbedding``.
    """

    _model: Any = PrivateAttr()

    def __init__(
        self,
        handle: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        """Load the model from ``handle`` (or ``DEFAULT_HANDLE``)."""
        handle = handle or DEFAULT_HANDLE
        try:
            import tensorflow_hub as hub
        except ImportError as err:
            raise ImportError(
                "Please install tensorflow_hub: `pip install tensorflow_hub`"
            ) from err

        # Loading is kept outside the try so that errors raised while loading
        # the model are not misreported as a missing package.
        model = hub.load(handle)

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=handle,
            **kwargs,
        )
        # Private attributes are assigned only after BaseModel.__init__ has
        # run; pydantic v2 rejects private-attribute writes before that.
        self._model = model

    @classmethod
    def class_name(cls) -> str:
        return "GoogleUnivSentEncoderEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a query string."""
        return self._get_embedding(query)

    # TODO: use proper async methods
    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Get the embedding for a text (runs synchronously).

        The parameter keeps its original name ``query`` for compatibility;
        it receives the text to embed.
        """
        return self._get_embedding(query)

    # TODO: use proper async methods
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a query string (runs synchronously)."""
        return self._get_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text."""
        return self._get_embedding(text)

    def _get_embedding(self, text: str) -> List[float]:
        """Run the model on a one-element batch and return its single row."""
        vectors = self._model([text]).numpy().tolist()
        return vectors[0]