Source code for langchain_community.embeddings.laser

from typing import Any, Dict, List, Optional

import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

LASER_MULTILINGUAL_MODEL: str = "laser2"


[docs]class LaserEmbeddings(BaseModel, Embeddings): """LASER语言无关句子表示。 LASER是Meta AI研究团队开发的Python库, 用于为147种以上的语言创建多语言句子嵌入,截至2024年2月25日。 查看更多文档: * https://github.com/facebookresearch/LASER/ * https://github.com/facebookresearch/LASER/tree/main/laser_encoders * https://arxiv.org/abs/2205.12654 要使用这个类,您必须安装`laser_encoders` Python包。 `pip install laser_encoders` 示例: from laser_encoders import LaserEncoderPipeline encoder = LaserEncoderPipeline(lang="eng_Latn") embeddings = encoder.encode_sentences(["Hello", "World"]) """ lang: Optional[str] """您想使用的语言或语言代码 如果为空,此实现将默认使用一个多语言的早期LASER编码器模型(称为laser2) 在以下链接中找到支持的语言列表 https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200""" _encoder_pipeline: Any # : :meta private: class Config: """此pydantic对象的配置。""" extra = Extra.forbid @root_validator() def validate_environment(cls, values: Dict) -> Dict: """验证已安装laser_encoders。""" try: from laser_encoders import LaserEncoderPipeline lang = values.get("lang") if lang: encoder_pipeline = LaserEncoderPipeline(lang=lang) else: encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL) values["_encoder_pipeline"] = encoder_pipeline except ImportError as e: raise ImportError( "Could not import 'laser_encoders' Python package. " "Please install it with `pip install laser_encoders`." ) from e return values
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """为文档使用LASER生成嵌入。 参数: texts: 要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ embeddings: np.ndarray embeddings = self._encoder_pipeline.encode_sentences(texts) return embeddings.tolist()
[docs] def embed_query(self, text: str) -> List[float]: """使用LASER生成单个查询文本嵌入。 参数: text: 要嵌入的文本。 返回: 文本的嵌入。 """ query_embeddings: np.ndarray query_embeddings = self._encoder_pipeline.encode_sentences([text]) return query_embeddings.tolist()[0]