Source code for langchain_community.storage.astradb

from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore

from langchain_community.utilities.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)

if TYPE_CHECKING:
    from astrapy.db import AstraDB, AsyncAstraDB

V = TypeVar("V")


[docs]class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): """用于DataStax AstraDB数据存储的基类。"""
[docs] def __init__(self, *args: Any, **kwargs: Any) -> None: self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs) self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection
[docs] @abstractmethod def decode_value(self, value: Any) -> Optional[V]: """从Astra DB解码值"""
[docs] @abstractmethod def encode_value(self, value: Optional[V]) -> Any: """为Astra DB编码数值"""
[docs] def mget(self, keys: Sequence[str]) -> List[Optional[V]]: self.astra_env.ensure_db_setup() docs_dict = {} for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}): docs_dict[doc["_id"]] = doc.get("value") return [self.decode_value(docs_dict.get(key)) for key in keys]
[docs] async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: await self.astra_env.aensure_db_setup() docs_dict = {} async for doc in self.async_collection.paginated_find( filter={"_id": {"$in": list(keys)}} ): docs_dict[doc["_id"]] = doc.get("value") return [self.decode_value(docs_dict.get(key)) for key in keys]
[docs] def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: self.astra_env.ensure_db_setup() for k, v in key_value_pairs: self.collection.upsert({"_id": k, "value": self.encode_value(v)})
[docs] async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: await self.astra_env.aensure_db_setup() for k, v in key_value_pairs: await self.async_collection.upsert( {"_id": k, "value": self.encode_value(v)} )
[docs] def mdelete(self, keys: Sequence[str]) -> None: self.astra_env.ensure_db_setup() self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
[docs] async def amdelete(self, keys: Sequence[str]) -> None: await self.astra_env.aensure_db_setup() await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
[docs] def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: self.astra_env.ensure_db_setup() docs = self.collection.paginated_find() for doc in docs: key = doc["_id"] if not prefix or key.startswith(prefix): yield key
[docs] async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: await self.astra_env.aensure_db_setup() async for doc in self.async_collection.paginated_find(): key = doc["_id"] if not prefix or key.startswith(prefix): yield key
[docs]@deprecated( since="0.0.22", removal="0.3.0", alternative_import="langchain_astradb.AstraDBStore", ) class AstraDBStore(AstraDBBaseStore[Any]):
[docs] def __init__( self, collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, *, async_astra_db_client: Optional[AsyncAstraDB] = None, pre_delete_collection: bool = False, setup_mode: SetupMode = SetupMode.SYNC, ) -> None: """使用DataStax AstraDB作为基础存储的BaseStore实现。 值的类型可以是任何可以通过json.dumps序列化的类型。 可用于存储与CacheBackedEmbeddings一起使用的嵌入。 AstraDB集合中的文档将具有以下格式 .. code-block:: json { "_id": "<key>", "value": <value> } 参数: collection_name: 要创建/使用的Astra DB集合的名称。 token: 用于Astra DB使用的API令牌。 api_endpoint: API端点的完整URL, 例如 `https://<DB-ID>-us-east1.apps.astra.datastax.com`。 astra_db_client: *token+api_endpoint的替代*, 您可以传递一个已创建的'astrapy.db.AstraDB'实例。 async_astra_db_client: *token+api_endpoint的替代*, 您可以传递一个已创建的'astrapy.db.AsyncAstraDB'实例。 namespace: 创建集合的命名空间(又名键空间)。 默认为数据库的“默认命名空间”。 setup_mode: 用于创建Astra DB集合的模式(SYNC、ASYNC或OFF)。 pre_delete_collection: 是否在创建集合之前删除集合。 如果为False且集合已经存在,则将直接使用该集合。 """ # Constructor doc is not inherited so we have to override it. super().__init__( collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, namespace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, )
[docs] def decode_value(self, value: Any) -> Any: return value
[docs] def encode_value(self, value: Any) -> Any: return value
[docs]@deprecated( since="0.0.22", removal="0.3.0", alternative_import="langchain_astradb.AstraDBByteStore", ) class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
[docs] def __init__( self, collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, *, async_astra_db_client: Optional[AsyncAstraDB] = None, pre_delete_collection: bool = False, setup_mode: SetupMode = SetupMode.SYNC, ) -> None: """使用DataStax AstraDB作为底层存储的ByteStore实现。 字节值将被转换为base64编码的字符串。 AstraDB集合中的文档将具有以下格式 .. code-block:: json { "_id": "<key>", "value": "<byte64 string value>" } 参数: collection_name: 要创建/使用的Astra DB集合的名称。 token: 用于Astra DB使用的API令牌。 api_endpoint: API端点的完整URL, 例如 `https://<DB-ID>-us-east1.apps.astra.datastax.com`。 astra_db_client: *token+api_endpoint的替代*, 您可以传递一个已创建的'astrapy.db.AstraDB'实例。 async_astra_db_client: *token+api_endpoint的替代*, 您可以传递一个已创建的'astrapy.db.AsyncAstraDB'实例。 namespace: 创建集合的命名空间(又名键空间)。 默认为数据库的“默认命名空间”。 setup_mode: 用于创建Astra DB集合的模式(SYNC、ASYNC或OFF)。 pre_delete_collection: 是否在创建集合之前删除集合。 如果为False且集合已经存在,则将直接使用该集合。 """ # Constructor doc is not inherited so we have to override it. super().__init__( collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, namespace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, )
[docs] def decode_value(self, value: Any) -> Optional[bytes]: if value is None: return None return base64.b64decode(value)
[docs] def encode_value(self, value: Optional[bytes]) -> Any: if value is None: return None return base64.b64encode(value).decode("ascii")