Source code for langchain_community.embeddings.tensorflow_hub

from typing import Any, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra

DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"


[docs]class TensorflowHubEmbeddings(BaseModel, Embeddings): """TensorflowHub嵌入模型。 要使用,您应该安装``tensorflow_text`` python包。 示例: .. code-block:: python from langchain_community.embeddings import TensorflowHubEmbeddings url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" tf = TensorflowHubEmbeddings(model_url=url)""" embed: Any #: :meta private: model_url: str = DEFAULT_MODEL_URL """要使用的模型名称。""" def __init__(self, **kwargs: Any): """初始化tensorflow_hub和tensorflow_text。""" super().__init__(**kwargs) try: import tensorflow_hub except ImportError: raise ImportError( "Could not import tensorflow-hub python package. " "Please install it with `pip install tensorflow-hub``." ) try: import tensorflow_text # noqa except ImportError: raise ImportError( "Could not import tensorflow_text python package. " "Please install it with `pip install tensorflow_text``." ) self.embed = tensorflow_hub.load(self.model_url) class Config: """此pydantic对象的配置。""" extra = Extra.forbid
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """使用TensorflowHub嵌入模型计算文档嵌入。 参数: texts:要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ texts = list(map(lambda x: x.replace("\n", " "), texts)) embeddings = self.embed(texts).numpy() return embeddings.tolist()
[docs] def embed_query(self, text: str) -> List[float]: """使用TensorflowHub嵌入模型计算查询嵌入。 参数: text:要嵌入的文本。 返回: 文本的嵌入。 """ text = text.replace("\n", " ") embedding = self.embed([text]).numpy()[0] return embedding.tolist()