from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB
V = TypeVar("V")
[docs]class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""用于DataStax AstraDB数据存储的基类。"""
[docs] def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
[docs] @abstractmethod
def decode_value(self, value: Any) -> Optional[V]:
"""从Astra DB解码值"""
[docs] @abstractmethod
def encode_value(self, value: Optional[V]) -> Any:
"""为Astra DB编码数值"""
[docs] def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
[docs] async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
[docs] def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert({"_id": k, "value": self.encode_value(v)})
[docs] async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert(
{"_id": k, "value": self.encode_value(v)}
)
[docs] def mdelete(self, keys: Sequence[str]) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
[docs] async def amdelete(self, keys: Sequence[str]) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
[docs] def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
[docs] async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
[docs]@deprecated(
since="0.0.22",
removal="0.3.0",
alternative_import="langchain_astradb.AstraDBStore",
)
class AstraDBStore(AstraDBBaseStore[Any]):
[docs] def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""使用DataStax AstraDB作为基础存储的BaseStore实现。
值的类型可以是任何可以通过json.dumps序列化的类型。
可用于存储与CacheBackedEmbeddings一起使用的嵌入。
AstraDB集合中的文档将具有以下格式
.. code-block:: json
{
"_id": "<key>",
"value": <value>
}
参数:
collection_name: 要创建/使用的Astra DB集合的名称。
token: 用于Astra DB使用的API令牌。
api_endpoint: API端点的完整URL,
例如 `https://<DB-ID>-us-east1.apps.astra.datastax.com`。
astra_db_client: *token+api_endpoint的替代*,
您可以传递一个已创建的'astrapy.db.AstraDB'实例。
async_astra_db_client: *token+api_endpoint的替代*,
您可以传递一个已创建的'astrapy.db.AsyncAstraDB'实例。
namespace: 创建集合的命名空间(又名键空间)。
默认为数据库的“默认命名空间”。
setup_mode: 用于创建Astra DB集合的模式(SYNC、ASYNC或OFF)。
pre_delete_collection: 是否在创建集合之前删除集合。
如果为False且集合已经存在,则将直接使用该集合。
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
[docs] def decode_value(self, value: Any) -> Any:
return value
[docs] def encode_value(self, value: Any) -> Any:
return value
[docs]@deprecated(
since="0.0.22",
removal="0.3.0",
alternative_import="langchain_astradb.AstraDBByteStore",
)
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
[docs] def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""使用DataStax AstraDB作为底层存储的ByteStore实现。
字节值将被转换为base64编码的字符串。
AstraDB集合中的文档将具有以下格式
.. code-block:: json
{
"_id": "<key>",
"value": "<byte64 string value>"
}
参数:
collection_name: 要创建/使用的Astra DB集合的名称。
token: 用于Astra DB使用的API令牌。
api_endpoint: API端点的完整URL,
例如 `https://<DB-ID>-us-east1.apps.astra.datastax.com`。
astra_db_client: *token+api_endpoint的替代*,
您可以传递一个已创建的'astrapy.db.AstraDB'实例。
async_astra_db_client: *token+api_endpoint的替代*,
您可以传递一个已创建的'astrapy.db.AsyncAstraDB'实例。
namespace: 创建集合的命名空间(又名键空间)。
默认为数据库的“默认命名空间”。
setup_mode: 用于创建Astra DB集合的模式(SYNC、ASYNC或OFF)。
pre_delete_collection: 是否在创建集合之前删除集合。
如果为False且集合已经存在,则将直接使用该集合。
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
[docs] def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value)
[docs] def encode_value(self, value: Optional[bytes]) -> Any:
if value is None:
return None
return base64.b64encode(value).decode("ascii")