Source code for langchain_community.embeddings.nemo

from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional

import aiohttp
import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator


def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
    """Check whether an endpoint is live by sending it a POST request.

    Args:
        url: The URL of the endpoint to check.
        headers: HTTP headers sent with the request, or None.
        payload: Request body, passed to ``requests`` as ``data``.

    Returns:
        True if the endpoint answers with status code 200. The function
        never returns False; every other outcome raises.

    Raises:
        Exception: If the endpoint returns a status code other than 200, or
            if the request itself fails (e.g. a connection error). In the
            latter case the underlying ``requests`` exception is attached
            as ``__cause__``.
    """
    try:
        response = requests.request("POST", url, headers=headers, data=payload)
    except requests.exceptions.RequestException as e:
        # Transport-level failure (connection error, invalid URL, ...).
        # Chain the original exception so the root cause stays visible.
        raise Exception(f"Error querying the endpoint: {e}") from e

    # Only an exact 200 counts as live; anything else is reported as an error.
    if response.status_code != 200:
        raise Exception(
            f"Endpoint returned a non-successful status code: "
            f"{response.status_code}"
        )
    return True
@deprecated(
    since="0.0.37",
    removal="0.2.0",
    message=(
        "Directly instantiating a NeMoEmbeddings from langchain-community is "
        "deprecated. Please use langchain-nvidia-ai-endpoints NVIDIAEmbeddings "
        "interface."
    ),
)
class NeMoEmbeddings(BaseModel, Embeddings):
    """NeMo embedding models, queried over an HTTP embeddings endpoint."""

    # Number of texts whose requests are issued concurrently per round in
    # ``aembed_documents``. The synchronous methods do not use it.
    batch_size: int = 16
    # Model name sent in the ``model`` field of every request.
    model: str = "NV-Embed-QA-003"
    # URL that every embedding request is POSTed to.
    api_endpoint_url: str = "http://localhost:8088/v1/embeddings"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the configured endpoint is live.

        Sends one small test embedding request built from the provided
        values; ``is_endpoint_live`` raises if the endpoint does not answer
        with status 200.
        """
        url = values["api_endpoint_url"]
        model = values["model"]

        # Minimal test payload and headers required by the endpoint.
        headers = {"Content-Type": "application/json"}
        payload = json.dumps(
            {"input": "Hello World", "model": model, "input_type": "query"}
        )

        is_endpoint_live(url, headers, payload)

        return values

    async def _aembedding_func(
        self, session: Any, text: str, input_type: str
    ) -> List[float]:
        """Call the embedding endpoint asynchronously for a single text.

        Args:
            session: Client session used to issue the POST request (the
                callers in this class pass an ``aiohttp.ClientSession``).
            text: The text to embed.
            input_type: Value sent in the ``input_type`` field of the
                request ("query" or "passage" in this class).

        Returns:
            The embedding of the text.
        """
        headers = {"Content-Type": "application/json"}

        async with session.post(
            self.api_endpoint_url,
            json={"input": text, "model": self.model, "input_type": input_type},
            headers=headers,
        ) as response:
            response.raise_for_status()
            answer = await response.text()
            answer = json.loads(answer)
            return answer["data"][0]["embedding"]

    def _embedding_func(self, text: str, input_type: str) -> List[float]:
        """Call the NeMo embedding endpoint synchronously for a single text.

        Args:
            text: The text to embed.
            input_type: Value sent in the ``input_type`` field of the
                request ("query" or "passage" in this class).

        Returns:
            The embedding of the text.

        Raises:
            requests.exceptions.HTTPError: If the endpoint answers with an
                HTTP error status.
        """
        payload = json.dumps(
            {"input": text, "model": self.model, "input_type": input_type}
        )
        headers = {"Content-Type": "application/json"}

        response = requests.request(
            "POST", self.api_endpoint_url, headers=headers, data=payload
        )
        # Fail with a clear HTTP error, as the async path does, instead of
        # a confusing KeyError/JSON error when parsing an error body.
        response.raise_for_status()
        response_json = json.loads(response.text)
        embedding = response_json["data"][0]["embedding"]

        return embedding

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        One request is sent per document, sequentially, with
        ``input_type="passage"``.

        Args:
            documents: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return [self._embedding_func(text, input_type="passage") for text in documents]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text with ``input_type="query"``.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        return self._embedding_func(text, input_type="query")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        async with aiohttp.ClientSession() as session:
            # Use the same input_type as the synchronous embed_query so both
            # paths send the same request for the same text.
            embedding = await self._aembedding_func(session, text, "query")
            return embedding

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of document texts.

        Texts are processed in slices of ``batch_size``; the requests of one
        slice run concurrently, and slices are processed one after another.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        embeddings: List[List[float]] = []

        async with aiohttp.ClientSession() as session:
            for start in range(0, len(texts), self.batch_size):
                text_batch = texts[start : start + self.batch_size]

                # Create exactly one task per text in the batch.
                tasks = [
                    self._aembedding_func(session, text, "passage")
                    for text in text_batch
                ]

                # Run all tasks concurrently; gather returns results in the
                # same order as the tasks were passed in.
                batch_results = await asyncio.gather(*tasks)

                # Extend the embeddings list with results from this batch.
                embeddings.extend(batch_results)

        return embeddings