Skip to content

Gemini

GeminiEmbedding #

Bases: BaseEmbedding

谷歌 Gemini 嵌入。

Parameters:

Name Type Description Default
model_name str

嵌入模型。默认为 "models/embedding-001"。

'models/embedding-001'
api_key Optional[str]

访问模型的 API 密钥。默认为 None；未提供时读取环境变量 GOOGLE_API_KEY。

None
api_base Optional[str]

访问模型的 API 基础地址（endpoint）。默认为 None，即使用官方地址。

None
transport Optional[str]

访问模型的传输方式，可选 `rest`、`grpc` 或 `grpc_asyncio`。

None
Source code in llama_index/embeddings/gemini/base.py (lines 15–117)
class GeminiEmbedding(BaseEmbedding):
    """Google Gemini embeddings.

    Args:
        model_name (str): Embedding model to use.
            Defaults to "models/embedding-001".
        task_type (Optional[str]): Task type sent with every embedding
            request (queries and documents alike).
            Defaults to "retrieval_document".
        api_key (Optional[str]): API key used to access the model. When not
            given, the ``GOOGLE_API_KEY`` environment variable is used.
            Defaults to None.
        api_base (Optional[str]): API endpoint used to access the model.
            Defaults to None, i.e. the official endpoint.
        transport (Optional[str]): Transport used to access the model, one of
            ``rest``, ``grpc`` or ``grpc_asyncio``. Defaults to None.
        title (Optional[str]): Title sent with every embedding request; only
            applicable to ``retrieval_document`` tasks. Defaults to None.
        embed_batch_size (int): Passed through to ``BaseEmbedding``.
        callback_manager (Optional[CallbackManager]): Passed through to
            ``BaseEmbedding``.
        **kwargs: Extra keyword arguments forwarded to ``BaseEmbedding``.
    """

    # Handle to the configured Gemini client module (assigned in __init__).
    _model: Any = PrivateAttr()
    title: Optional[str] = Field(
        default="",
        description="Title is only applicable for retrieval_document tasks, and is used to represent a document title. For other tasks, title is invalid.",
    )
    task_type: Optional[str] = Field(
        default="retrieval_document",
        description="The task for embedding model.",
    )

    def __init__(
        self,
        model_name: str = "models/embedding-001",
        task_type: Optional[str] = "retrieval_document",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        title: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        # API keys are optional. The API can be authorised via OAuth (detected
        # environmentally) or by the GOOGLE_API_KEY environment variable.
        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            # transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
            config_params["transport"] = transport
        # NOTE(review): configure() is called on the imported module itself,
        # so these settings presumably apply process-wide — confirm.
        gemini.configure(**config_params)

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            title=title,
            task_type=task_type,
            **kwargs,
        )
        # Private attributes may only be assigned once the pydantic model has
        # been initialised: before super().__init__() pydantic v2 has not yet
        # created the private-attribute storage and the assignment fails.
        self._model = gemini

    @classmethod
    def class_name(cls) -> str:
        return "GeminiEmbedding"

    def _embed_single_text(self, content: str) -> List[float]:
        """Embed one string using the configured model, title and task type."""
        return self._model.embed_content(
            model=self.model_name,
            content=content,
            title=self.title,
            task_type=self.task_type,
        )["embedding"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a query string."""
        return self._embed_single_text(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text."""
        return self._embed_single_text(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for several texts, issuing one request per text."""
        return [self._embed_single_text(text) for text in texts]

    ### Async methods ###
    # need to wait async calls from Gemini side to be implemented.
    # Issue: https://github.com/google/generative-ai-python/issues/125
    # Until then these delegate to the synchronous methods and therefore
    # block the running event loop for the duration of each request.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding`` (runs synchronously)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding`` (runs synchronously)."""
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings`` (runs synchronously)."""
        return self._get_text_embeddings(texts)