Source code for langchain_community.cache

"""
.. 警告::
  Beta 功能!

**Cache**  为 LLMs 提供了一个可选的缓存层。

Cache 有两个用处:

- 如果你经常多次请求相同的完成,它可以通过减少向 LLM 提供商发出的 API 调用次数来为您节省金钱。
- 它可以通过减少向 LLM 提供商发出的 API 调用次数来加快应用程序的速度。

Cache 与 Memory 直接竞争。请参阅文档以了解优缺点。

**类层次结构:** 

.. code-block::

    BaseCache --> <name>Cache  # 例如: InMemoryCache, RedisCache, GPTCache
"""

from __future__ import annotations

import hashlib
import inspect
import json
import logging
import uuid
import warnings
from abc import ABC
from datetime import timedelta
from enum import Enum
from functools import lru_cache, wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from sqlalchemy import Column, Integer, String, create_engine, delete, select
from sqlalchemy.engine import Row
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session

from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
from langchain_community.vectorstores.azure_cosmos_db import (
    CosmosDBSimilarityType,
    CosmosDBVectorSearchType,
)

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from langchain_core._api.deprecation import deprecated, warn_deprecated
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env

from langchain_community.utilities.astradb import (
    SetupMode as AstraSetupMode,
)
from langchain_community.utilities.astradb import (
    _AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_community.vectorstores import (
    OpenSearchVectorSearch as OpenSearchVectorStore,
)
from langchain_community.vectorstores.redis import Redis as RedisVectorstore

logger = logging.getLogger(__file__)

if TYPE_CHECKING:
    import momento
    from astrapy.db import AstraDB, AsyncAstraDB
    from cassandra.cluster import Session as CassandraSession


def _hash(_input: str) -> str:
    """使用确定性哈希方法。"""
    return hashlib.md5(_input.encode()).hexdigest()


def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
    """将生成的内容转储为json。

参数:
    generations(RETURN_VAL_TYPE):语言模型生成的列表。

返回:
    str:表示生成列表的Json。

警告:不适用于“Generation”的任意子类。
"""
    return json.dumps([generation.dict() for generation in generations])


def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
    """从json中加载生成。

参数:
    generations_json(str):表示生成列表的json字符串。

引发:
    ValueError:无法将json字符串解码为生成列表。

返回:
    RETURN_VAL_TYPE:生成列表。

警告:与“Generation”的任意子类不兼容。
"""
    try:
        results = json.loads(generations_json)
        return [Generation(**generation_dict) for generation_dict in results]
    except json.JSONDecodeError:
        raise ValueError(
            f"Could not decode json to list of generations: {generations_json}"
        )


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """为通用的RETURN_VAL_TYPE进行序列化,即`Generation`的序列

参数:
    generations (RETURN_VAL_TYPE): 一个语言模型生成的列表。

返回:
    str: 代表生成列表的单个字符串。

这个函数(及其对应的`_loads_generations`)依赖于带有Reviver的dumps/loads对,因此能够处理Generation的所有子类。

列表中的每个项目都可以被`dumps`为一个字符串,
然后我们将整个字符串列表转换为json-dumped。
"""
    return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """将字符串反序列化为通用的RETURN_VAL_TYPE(即`Generation`序列)。

请参阅`_dumps_generations`,这是该函数的反函数。

参数:
    generations_str(str):表示一组generations的字符串。

与旧缓存blob格式兼容
不会因为格式错误的条目而引发异常,只会记录警告并返回None:调用者应该为这种缓存未命中做好准备。

返回:
    RETURN_VAL_TYPE:一组generations的列表。
"""
    try:
        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
        return generations
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        gen_dicts = json.loads(generations_str)
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None


class InMemoryCache(BaseCache):
    """Cache that stores things in memory."""

    def __init__(self) -> None:
        """Initialize with empty cache."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return self._cache.get((prompt, llm_string), None)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self._cache[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        self._cache = {}

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return self.lookup(prompt, llm_string)

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        """Update cache based on prompt and llm_string."""
        self.update(prompt, llm_string, return_val)

    async def aclear(self, **kwargs: Any) -> None:
        """Clear cache."""
        self.clear()
Base = declarative_base()
class FullLLMCache(Base):  # type: ignore
    """SQLite table for full LLM Cache (all generations)."""

    __tablename__ = "full_llm_cache"
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    idx = Column(Integer, primary_key=True)
    response = Column(String)
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(
        self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache
    ):
        """Initialize by creating all tables."""
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt == prompt)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            rows = session.execute(stmt).fetchall()
            if rows:
                try:
                    return [loads(row[0]) for row in rows]
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    return [Generation(text=row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string."""
        items = [
            self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
            for i, gen in enumerate(return_val)
        ]
        with Session(self.engine) as session, session.begin():
            for item in items:
                session.merge(item)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        with Session(self.engine) as session:
            session.query(self.cache_schema).delete()
            session.commit()
class SQLiteCache(SQLAlchemyCache):
    """Cache that uses SQLite as a backend."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Initialize by creating the engine and all tables."""
        engine = create_engine(f"sqlite:///{database_path}")
        super().__init__(engine)
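def _example_sqlite_cache() -> None:
    # Illustrative sketch (not part of the upstream module): persist the
    # LLM cache in a local SQLite file.
    from langchain_core.globals import set_llm_cache

    set_llm_cache(SQLiteCache(database_path=".langchain.db"))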
class UpstashRedisCache(BaseCache):
    """Cache that uses Upstash Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of UpstashRedisCache.

        This method initializes an object with Upstash Redis caching capabilities.
        It takes a `redis_` parameter, which should be an instance of an Upstash
        Redis client class, allowing the object to interact with an Upstash Redis
        server for caching purposes.

        Parameters:
            redis_: An instance of an Upstash Redis client class
                (e.g., Redis) used for caching.
                This allows the object to communicate with a
                Redis server for caching operations.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, it sets the time duration for how long cached
                items will remain valid. If not provided, cached items will not
                have an automatic expiration.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ImportError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Upstash Redis object.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        generations = []
        # Read from a HASH
        results = self.redis.hgetall(self._key(prompt, llm_string))
        if results:
            for _, text in results.items():
                generations.append(Generation(text=text))
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "UpstashRedisCache supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
            if isinstance(gen, ChatGeneration):
                warnings.warn(
                    "NOTE: Generation has not been cached. UpstashRedisCache does not"
                    " support caching ChatModel outputs."
                )
                return

        # Write to a HASH
        key = self._key(prompt, llm_string)

        mapping = {
            str(idx): generation.text for idx, generation in enumerate(return_val)
        }
        self.redis.hset(key=key, values=mapping)

        if self.ttl is not None:
            self.redis.expire(key, self.ttl)

    def clear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        This flushes the *whole* db.
        """
        asynchronous = kwargs.get("asynchronous", False)
        if asynchronous:
            asynchronous = "ASYNC"
        else:
            asynchronous = "SYNC"
        self.redis.flushdb(flush_type=asynchronous)
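def _example_upstash_redis_cache() -> None:
    # Illustrative sketch (not part of the upstream module); the URL and
    # token below are placeholders for real Upstash credentials.
    from upstash_redis import Redis
    from langchain_core.globals import set_llm_cache

    client = Redis(url="https://<your-db>.upstash.io", token="<your-token>")
    set_llm_cache(UpstashRedisCache(redis_=client, ttl=3600))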
class _RedisCacheBase(BaseCache, ABC):
    @staticmethod
    def _key(prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    @staticmethod
    def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None:
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )

    @staticmethod
    def _get_generations(
        results: dict[str | bytes, str | bytes],
    ) -> Optional[List[Generation]]:
        generations = []
        if results:
            for _, text in results.items():
                try:
                    generations.append(loads(cast(str, text)))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.append(Generation(text=text))  # type: ignore[arg-type]
        return generations if generations else None

    @staticmethod
    def _configure_pipeline_for_update(
        key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
    ) -> None:
        pipe.hset(
            key,
            mapping={
                str(idx): dumps(generation)
                for idx, generation in enumerate(return_val)
            },
        )
        if ttl is not None:
            pipe.expire(key, ttl)
class RedisCache(_RedisCacheBase):
    """
    Cache that uses Redis as a backend. Allows to use a sync `redis.Redis` client.
    """

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of RedisCache.

        This method initializes an object with Redis caching capabilities.
        It takes a `redis_` parameter, which should be an instance of a Redis
        client class (`redis.Redis`), allowing the object
        to interact with a Redis server for caching purposes.

        Parameters:
            redis_ (Any): An instance of a Redis client class
                (`redis.Redis`) to be used for caching.
                This allows the object to communicate with a
                Redis server for caching operations.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, it sets the time duration for how long cached
                items will remain valid. If not provided, cached items will not
                have an automatic expiration.
        """
        try:
            from redis import Redis
        except ImportError:
            raise ImportError(
                "Could not import `redis` python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass a valid `redis.Redis` client.")
        self.redis = redis_
        self.ttl = ttl

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        # Read from a Redis HASH
        try:
            results = self.redis.hgetall(self._key(prompt, llm_string))
            return self._get_generations(results)  # type: ignore[arg-type]
        except Exception as e:
            logger.error(f"Redis lookup failed: {e}")
            return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self._ensure_generation_type(return_val)
        key = self._key(prompt, llm_string)
        try:
            with self.redis.pipeline() as pipe:
                self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
                pipe.execute()
        except Exception as e:
            logger.error(f"Redis update failed: {e}")

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously."""
        try:
            asynchronous = kwargs.get("asynchronous", False)
            self.redis.flushdb(asynchronous=asynchronous, **kwargs)
        except Exception as e:
            logger.error(f"Redis clear failed: {e}")
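def _example_redis_cache() -> None:
    # Illustrative sketch (not part of the upstream module): cache LLM
    # generations in a local Redis server with a 10-minute TTL.
    from redis import Redis
    from langchain_core.globals import set_llm_cache

    set_llm_cache(RedisCache(redis_=Redis(host="localhost", port=6379), ttl=600))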
class AsyncRedisCache(_RedisCacheBase):
    """
    Cache that uses Redis as a backend. Allows to use an
    async `redis.asyncio.Redis` client.
    """

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of AsyncRedisCache.

        This method initializes an object with Redis caching capabilities.
        It takes a `redis_` parameter, which should be an instance of a Redis
        client class (`redis.asyncio.Redis`), allowing the object
        to interact with a Redis server for caching purposes.

        Parameters:
            redis_ (Any): An instance of a Redis client class
                (`redis.asyncio.Redis`) to be used for caching.
                This allows the object to communicate with a
                Redis server for caching operations.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, it sets the time duration for how long cached
                items will remain valid. If not provided, cached items will not
                have an automatic expiration.
        """
        try:
            from redis.asyncio import Redis
        except ImportError:
            raise ImportError(
                "Could not import `redis.asyncio` python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass a valid `redis.asyncio.Redis` client.")
        self.redis = redis_
        self.ttl = ttl

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        raise NotImplementedError(
            "This async Redis cache does not implement `lookup()` method. "
            "Consider using the async `alookup()` version."
        )

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string. Async version."""
        try:
            results = await self.redis.hgetall(self._key(prompt, llm_string))
            return self._get_generations(results)  # type: ignore[arg-type]
        except Exception as e:
            logger.error(f"Redis async lookup failed: {e}")
            return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        raise NotImplementedError(
            "This async Redis cache does not implement `update()` method. "
            "Consider using the async `aupdate()` version."
        )

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        """Update cache based on prompt and llm_string. Async version."""
        self._ensure_generation_type(return_val)
        key = self._key(prompt, llm_string)
        try:
            async with self.redis.pipeline() as pipe:
                self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
                await pipe.execute()  # type: ignore[attr-defined]
        except Exception as e:
            logger.error(f"Redis async update failed: {e}")

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously."""
        raise NotImplementedError(
            "This async Redis cache does not implement `clear()` method. "
            "Consider using the async `aclear()` version."
        )

    async def aclear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        Async version.
        """
        try:
            asynchronous = kwargs.get("asynchronous", False)
            await self.redis.flushdb(asynchronous=asynchronous, **kwargs)
        except Exception as e:
            logger.error(f"Redis async clear failed: {e}")
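async def _example_async_redis_cache() -> None:
    # Illustrative sketch (not part of the upstream module): same setup as
    # RedisCache, but with the asyncio client and the async cache API.
    from redis.asyncio import Redis
    from langchain_core.globals import set_llm_cache

    set_llm_cache(AsyncRedisCache(redis_=Redis(host="localhost", port=6379)))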
class RedisSemanticCache(BaseCache):
    """Cache that uses Redis as a vector-store backend."""

    # TODO - implement a TTL policy in Redis

    DEFAULT_SCHEMA = {
        "content_key": "prompt",
        "text": [
            {"name": "prompt"},
        ],
        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
    }

    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """Initialize the semantic cache with a Redis vector store.

        Args:
            redis_url (str): URL to connect to Redis.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            score_threshold (float, 0.2):

        Example:

        .. code-block:: python

            from langchain_core.globals import set_llm_cache
            from langchain_community.cache import RedisSemanticCache
            from langchain_community.embeddings import OpenAIEmbeddings

            set_llm_cache(RedisSemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
            ))
        """
        self._cache_dict: Dict[str, RedisVectorstore] = {}
        self.redis_url = redis_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        try:
            self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
        except ValueError:
            redis = RedisVectorstore(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
            _embedding = self.embedding.embed_query(text="test")
            redis._create_index_if_not_exist(dim=len(_embedding))
            self._cache_dict[index_name] = redis

        return self._cache_dict[index_name]

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].drop_index(
                index_name=index_name, delete_documents=True, redis_url=self.redis_url
            )
            del self._cache_dict[index_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Read from a Hash
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            distance_threshold=self.score_threshold,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
class GPTCache(BaseCache):
    """Cache that uses GPTCache as a backend."""

    def __init__(
        self,
        init_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = None,
    ):
        """Initialize by passing in init function (default: `None`).

        Args:
            init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
                (default: `None`)

        Example:

        .. code-block:: python

            # Initialize GPTCache with a custom init function
            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import manager_factory
            from langchain_core.globals import set_llm_cache

            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other

            def init_gptcache(cache_obj: gptcache.Cache, llm: str) -> None:
                cache_obj.init(
                    pre_embedding_func=get_prompt,
                    data_manager=manager_factory(
                        manager="map",
                        data_dir=f"map_cache_{llm}"
                    ),
                )

            set_llm_cache(GPTCache(init_gptcache))
        """
        try:
            import gptcache  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gptcache python package. "
                "Please install it with `pip install gptcache`."
            )

        self.init_gptcache_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = init_func
        self.gptcache_dict: Dict[str, Any] = {}

    def _new_gptcache(self, llm_string: str) -> Any:
        """New gptcache object"""
        from gptcache import Cache
        from gptcache.manager.factory import get_data_manager
        from gptcache.processor.pre import get_prompt

        _gptcache = Cache()
        if self.init_gptcache_func is not None:
            sig = inspect.signature(self.init_gptcache_func)
            if len(sig.parameters) == 2:
                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
            else:
                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
        else:
            _gptcache.init(
                pre_embedding_func=get_prompt,
                data_manager=get_data_manager(data_path=llm_string),
            )

        self.gptcache_dict[llm_string] = _gptcache
        return _gptcache

    def _get_gptcache(self, llm_string: str) -> Any:
        """Get a cache object.

        When the corresponding llm model cache does not exist, it will be created.
        """
        _gptcache = self.gptcache_dict.get(llm_string, None)
        if not _gptcache:
            _gptcache = self._new_gptcache(llm_string)
        return _gptcache

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cache data.

        First, retrieve the corresponding cache object using the `llm_string`
        parameter, and then retrieve the data from the cache based on the `prompt`.
        """
        from gptcache.adapter.api import get

        _gptcache = self._get_gptcache(llm_string)

        res = get(prompt, cache_obj=_gptcache)
        return _loads_generations(res) if res is not None else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache.

        First, retrieve the corresponding cache object using the `llm_string`
        parameter, and then store the `prompt` and `return_val` in the cache object.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "GPTCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        from gptcache.adapter.api import put

        _gptcache = self._get_gptcache(llm_string)
        handled_data = _dumps_generations(return_val)
        put(prompt, handled_data, cache_obj=_gptcache)
        return None

    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        from gptcache import Cache

        for gptcache_instance in self.gptcache_dict.values():
            gptcache_instance = cast(Cache, gptcache_instance)
            gptcache_instance.flush()

        self.gptcache_dict.clear()
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
    """Create cache if it doesn't exist.

    Raises:
        SdkException: Momento service or network error
        Exception: Unexpected response
    """
    from momento.responses import CreateCache

    create_cache_response = cache_client.create_cache(cache_name)
    if isinstance(create_cache_response, CreateCache.Success) or isinstance(
        create_cache_response, CreateCache.CacheAlreadyExists
    ):
        return None
    elif isinstance(create_cache_response, CreateCache.Error):
        raise create_cache_response.inner_exception
    else:
        raise Exception(f"Unexpected response cache creation: {create_cache_response}")


def _validate_ttl(ttl: Optional[timedelta]) -> None:
    if ttl is not None and ttl <= timedelta(seconds=0):
        raise ValueError(f"ttl must be positive but was {ttl}.")
class MomentoCache(BaseCache):
    """Cache that uses Momento as a backend. See https://gomomento.com/"""

    def __init__(
        self,
        cache_client: momento.CacheClient,
        cache_name: str,
        *,
        ttl: Optional[timedelta] = None,
        ensure_cache_exists: bool = True,
    ):
        """Instantiate a prompt cache using Momento as a backend.

        Note: to instantiate the cache client passed to MomentoCache,
        you must have a Momento account. See https://gomomento.com/.

        Args:
            cache_client (CacheClient): The Momento cache client.
            cache_name (str): The name of the cache to use to store the data.
            ttl (Optional[timedelta], optional): The time to live for the cache items.
                Defaults to None, ie use the client default TTL.
            ensure_cache_exists (bool, optional): Create the cache if it doesn't
                exist. Defaults to True.

        Raises:
            ImportError: Momento python package is not installed.
            TypeError: cache_client is not of type momento.CacheClientObject
            ValueError: ttl is non-null and not positive
        """
        try:
            from momento import CacheClient
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if not isinstance(cache_client, CacheClient):
            raise TypeError("cache_client must be a momento.CacheClient object.")
        _validate_ttl(ttl)
        if ensure_cache_exists:
            _ensure_cache_exists(cache_client, cache_name)

        self.cache_client = cache_client
        self.cache_name = cache_name
        self.ttl = ttl

    @classmethod
    def from_client_params(
        cls,
        cache_name: str,
        ttl: timedelta,
        *,
        configuration: Optional[momento.config.Configuration] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,  # for backwards compatibility
        **kwargs: Any,
    ) -> MomentoCache:
        """Construct cache from CacheClient parameters."""
        try:
            from momento import CacheClient, Configurations, CredentialProvider
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if configuration is None:
            configuration = Configurations.Laptop.v1()

        # Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
        try:
            api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
        except ValueError:
            api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
        credentials = CredentialProvider.from_string(api_key)
        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
        return cls(cache_client, cache_name, ttl=ttl, **kwargs)

    def __key(self, prompt: str, llm_string: str) -> str:
        """Compute cache key from prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Returns:
            str: The cache key.
        """
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up llm generations in cache by prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Raises:
            SdkException: Momento service or network error

        Returns:
            Optional[RETURN_VAL_TYPE]: A list of language model generations.
        """
        from momento.responses import CacheGet

        generations: RETURN_VAL_TYPE = []

        get_response = self.cache_client.get(
            self.cache_name, self.__key(prompt, llm_string)
        )
        if isinstance(get_response, CacheGet.Hit):
            value = get_response.value_string
            generations = _load_generations_from_json(value)
        elif isinstance(get_response, CacheGet.Miss):
            pass
        elif isinstance(get_response, CacheGet.Error):
            raise get_response.inner_exception
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store llm generations in cache.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model string.
            return_val (RETURN_VAL_TYPE): A list of language model generations.

        Raises:
            SdkException: Momento service or network error
            Exception: Unexpected response
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "Momento only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        key = self.__key(prompt, llm_string)
        value = _dump_generations_to_json(return_val)
        set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
        from momento.responses import CacheSet

        if isinstance(set_response, CacheSet.Success):
            pass
        elif isinstance(set_response, CacheSet.Error):
            raise set_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {set_response}")

    def clear(self, **kwargs: Any) -> None:
        """Clear the cache.

        Raises:
            SdkException: Momento service or network error
        """
        from momento.responses import CacheFlush

        flush_response = self.cache_client.flush_cache(self.cache_name)
        if isinstance(flush_response, CacheFlush.Success):
            pass
        elif isinstance(flush_response, CacheFlush.Error):
            raise flush_response.inner_exception
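def _example_momento_cache() -> None:
    # Illustrative sketch (not part of the upstream module); requires a
    # Momento account with credentials available to `from_client_params`
    # (e.g. via the MOMENTO_API_KEY environment variable, see above).
    from langchain_core.globals import set_llm_cache

    set_llm_cache(MomentoCache.from_client_params("langchain", ttl=timedelta(days=1)))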
CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache"
CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None
class CassandraCache(BaseCache):
    """
    Cache that uses Cassandra / Astra DB as a backend.

    Example:

        .. code-block:: python

            import cassio
            from langchain_community.cache import CassandraCache
            from langchain_core.globals import set_llm_cache

            cassio.init(auto=True)  # Requires env. variables, see CassIO docs

            set_llm_cache(CassandraCache())

    It uses a single Cassandra table.
    The lookup keys (which get to form the primary key) are:
        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters.
          (needed to prevent same-prompt-different-model collisions)

    Args:
        session: an open Cassandra session.
            Leave unspecified to use the global cassio init (see below)
        keyspace: the keyspace to use for storing the cache.
            Leave unspecified to use the global cassio init (see below)
        table_name: name of the Cassandra table to use as cache
        ttl_seconds: time-to-live for cache entries
            (default: None, i.e. forever)
        setup_mode: a value in langchain_community.utilities.cassandra.SetupMode.
            Choose between SYNC, ASYNC and OFF - the latter if the Cassandra
            table is guaranteed to exist already, for a faster initialization.

    Note:
        The session and keyspace parameters, when left out (or passed as None),
        fall back to the globally-available cassio settings if any are available.
        In other words, if a previously-run 'cassio.init(...)' has been
        executed anywhere in the code, Cassandra-based objects
        need not specify the connection parameters at all.
    """

    def __init__(
        self,
        session: Optional[CassandraSession] = None,
        keyspace: Optional[str] = None,
        table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
        ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
    ):
        if skip_provisioning:
            warn_deprecated(
                "0.0.33",
                name="skip_provisioning",
                alternative=(
                    "setup_mode=langchain_community.utilities.cassandra.SetupMode.OFF"
                ),
                pending=True,
            )
        try:
            from cassio.table import ElasticCassandraTable
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install -U cassio`."
            )

        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name
        self.ttl_seconds = ttl_seconds

        kwargs = {}
        if setup_mode == CassandraSetupMode.ASYNC:
            kwargs["async_setup"] = True

        self.kv_cache = ElasticCassandraTable(
            session=self.session,
            keyspace=self.keyspace,
            table=self.table_name,
            keys=["llm_string", "prompt"],
            primary_key_type=["TEXT", "TEXT"],
            ttl_seconds=self.ttl_seconds,
            skip_provisioning=skip_provisioning
            or setup_mode == CassandraSetupMode.OFF,
            **kwargs,
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        item = self.kv_cache.get(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )
        if item is not None:
            return _loads_generations(item["body_blob"])
        else:
            return None

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        item = await self.kv_cache.aget(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )
        if item is not None:
            return _loads_generations(item["body_blob"])
        else:
            return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        blob = _dumps_generations(return_val)
        self.kv_cache.put(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
            body_blob=blob,
        )

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        blob = _dumps_generations(return_val)
        await self.kv_cache.aput(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
            body_blob=blob,
        )

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `delete` with the LLM being passed.
        In case the llm.invoke(prompt) calls have a `stop` param, you should
        pass it here.
        """
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        return self.kv_cache.delete(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. This is for all LLMs at once."""
        self.kv_cache.clear()

    async def aclear(self, **kwargs: Any) -> None:
        """Clear cache. This is for all LLMs at once."""
        await self.kv_cache.aclear()
# This constant is in fact a similarity - the 'distance' name is kept for compatibility:
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None
CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
class CassandraSemanticCache(BaseCache):
    """
    Cache that uses Cassandra as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    Example:

        .. code-block:: python

            import cassio
            from langchain_community.cache import CassandraSemanticCache
            from langchain_core.globals import set_llm_cache

            cassio.init(auto=True)  # Requires env. variables, see CassIO docs

            my_embedding = ...

            set_llm_cache(CassandraSemanticCache(
                embedding=my_embedding,
                table_name="my_semantic_cache",
            ))

    It uses a single (vector) Cassandra table and stores, in principle,
    cached values from several LLMs, so the LLM's llm_string is part
    of the rows' primary keys.

    One can choose a similarity measure (default: "dot" for dot-product).
    Choosing another one ("cos", "l2") almost certainly requires threshold
    tuning (which may be in order nevertheless, even if sticking to "dot").

    Args:
        session: an open Cassandra session.
            Leave unspecified to use the global cassio init (see below)
        keyspace: the keyspace to use for storing the cache.
            Leave unspecified to use the global cassio init (see below)
        embedding: Embedding provider for semantic encoding and search.
        table_name: name of the Cassandra (vector) table to use as cache.
            There is a default for "simple" usage, but remember to specify
            different tables if several embedding models coexist in your app
            (they cannot share one cache table).
        distance_metric: an alias for the 'similarity_measure' parameter
            (see below). As the "distance" terminology is misleading, please
            prefer 'similarity_measure' for clarity.
        score_threshold: numeric value to use as cutoff for the similarity searches.
        ttl_seconds: time-to-live for cache entries
            (default: None, i.e. forever)
        similarity_measure: which measure to adopt for similarity searches.
            Note: this parameter is aliased by 'distance_metric' - however,
            it is suggested to use the "similarity" terminology since this
            value is in fact a similarity (i.e. higher means closer).
            Note that at most one of the two parameters 'distance_metric'
            and 'similarity_measure' can be provided.
        setup_mode: a value in langchain_community.utilities.cassandra.SetupMode.
            Choose between SYNC, ASYNC and OFF - the latter if the Cassandra
            table is guaranteed to exist already, for a faster initialization.

    Note:
        The session and keyspace parameters, when left out (or passed as None),
        fall back to the globally-available cassio settings if any are available.
        In other words, if a previously-run 'cassio.init(...)' has been
        executed anywhere in the code, Cassandra-based objects
        need not specify the connection parameters at all.
    """
    def __init__(
        self,
        session: Optional[CassandraSession] = None,
        keyspace: Optional[str] = None,
        embedding: Optional[Embeddings] = None,
        table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
        distance_metric: Optional[str] = None,
        score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
        ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
        similarity_measure: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
    ):
        if skip_provisioning:
            warn_deprecated(
                "0.0.33",
                name="skip_provisioning",
                alternative=(
                    "setup_mode=langchain_community.utilities.cassandra.SetupMode.OFF"
                ),
                pending=True,
            )
        try:
            from cassio.table import MetadataVectorCassandraTable
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install -U cassio`."
            )

        if not embedding:
            raise ValueError("Missing required parameter 'embedding'.")

        # detect if legacy 'distance_metric' parameter used
        if distance_metric is not None:
            # if passed, takes precedence over 'similarity_measure', but we warn:
            warn_deprecated(
                "0.0.33",
                name="distance_metric",
                alternative="similarity_measure",
                pending=True,
            )
            similarity_measure = distance_metric

        self.session = session
        self.keyspace = keyspace
        self.embedding = embedding
        self.table_name = table_name
        self.similarity_measure = similarity_measure
        self.score_threshold = score_threshold
        self.ttl_seconds = ttl_seconds

        # The contract for this class has separate lookup and update:
        # in order to spare some embedding calculations we cache them between
        # the two calls.
        # Note: each instance of this class has its own `_get_embedding` with
        # its own lru.
        @lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = _cache_embedding

        @_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        async def _acache_embedding(text: str) -> List[float]:
            return await self.embedding.aembed_query(text=text)

        self._aget_embedding = _acache_embedding

        embedding_dimension: Union[int, Awaitable[int], None] = None
        if setup_mode == CassandraSetupMode.ASYNC:
            embedding_dimension = self._aget_embedding_dimension()
        elif setup_mode == CassandraSetupMode.SYNC:
            embedding_dimension = self._get_embedding_dimension()

        kwargs = {}
        if setup_mode == CassandraSetupMode.ASYNC:
            kwargs["async_setup"] = True

        self.table = MetadataVectorCassandraTable(
            session=self.session,
            keyspace=self.keyspace,
            table=self.table_name,
            primary_key_type=["TEXT"],
            vector_dimension=embedding_dimension,
            ttl_seconds=self.ttl_seconds,
            metadata_indexing=("allow", {"_llm_string_hash"}),
            skip_provisioning=skip_provisioning
            or setup_mode == CassandraSetupMode.OFF,
            **kwargs,
        )
    def _get_embedding_dimension(self) -> int:
        return len(self._get_embedding(text="This is a sample sentence."))

    async def _aget_embedding_dimension(self) -> int:
        return len(await self._aget_embedding(text="This is a sample sentence."))
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        embedding_vector = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)
        body = _dumps_generations(return_val)
        metadata = {
            "_prompt": prompt,
            "_llm_string_hash": llm_string_hash,
        }
        row_id = f"{_hash(prompt)}-{llm_string_hash}"
        self.table.put(
            body_blob=body,
            vector=embedding_vector,
            row_id=row_id,
            metadata=metadata,
        )

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        embedding_vector = await self._aget_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)
        body = _dumps_generations(return_val)
        metadata = {
            "_prompt": prompt,
            "_llm_string_hash": llm_string_hash,
        }
        row_id = f"{_hash(prompt)}-{llm_string_hash}"
        await self.table.aput(
            body_blob=body,
            vector=embedding_vector,
            row_id=row_id,
            metadata=metadata,
        )
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        if hit_with_id is not None:
            return hit_with_id[1]
        else:
            return None

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        hit_with_id = await self.alookup_with_id(prompt, llm_string)
        if hit_with_id is not None:
            return hit_with_id[1]
        else:
            return None
    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry)
        """
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        hits = list(
            self.table.metric_ann_search(
                vector=prompt_embedding,
                metadata={"_llm_string_hash": _hash(llm_string)},
                n=1,
                metric=self.similarity_measure,
                metric_threshold=self.score_threshold,
            )
        )
        if hits:
            hit = hits[0]
            generations = _loads_generations(hit["body_blob"])
            if generations is not None:
                # this protects against malformed cached items:
                return (
                    hit["row_id"],
                    generations,
                )
            else:
                return None
        else:
            return None

    async def alookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry)
        """
        prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
        hits = list(
            await self.table.ametric_ann_search(
                vector=prompt_embedding,
                metadata={"_llm_string_hash": _hash(llm_string)},
                n=1,
                metric=self.similarity_measure,
                metric_threshold=self.score_threshold,
            )
        )
        if hits:
            hit = hits[0]
            generations = _loads_generations(hit["body_blob"])
            if generations is not None:
                # this protects against malformed cached items:
                return (
                    hit["row_id"],
                    generations,
                )
            else:
                return None
        else:
            return None
    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    async def alookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        llm_string = (
            await aget_prompts(
                {**llm.dict(), **{"stop": stop}},
                [],
            )
        )[1]
        return await self.alookup_with_id(prompt, llm_string=llm_string)
    def delete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with the ID. This is for the second step.
        """
        self.table.delete(row_id=document_id)

    async def adelete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with the ID. This is for the second step.
        """
        await self.table.adelete(row_id=document_id)
    def clear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        self.table.clear()

    async def aclear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        await self.table.aclear()
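def _example_semantic_cache_invalidation(cache: CassandraSemanticCache) -> None:
    # Illustrative sketch (not part of the upstream module) of the two-step
    # invalidation pattern described in the docstrings above: look up to
    # obtain a document id, then delete by that id. The prompt and
    # llm_string values are placeholders.
    hit = cache.lookup_with_id("my prompt", llm_string="<llm-config-string>")
    if hit is not None:
        document_id, _generations = hit
        cache.delete_by_document_id(document_id)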
class FullMd5LLMCache(Base):  # type: ignore
    """SQLite table for full LLM Cache (all generations)."""

    __tablename__ = "full_md5_llm_cache"
    id = Column(String, primary_key=True)
    prompt_md5 = Column(String, index=True)
    llm = Column(String, index=True)
    idx = Column(Integer, index=True)
    prompt = Column(String)
    response = Column(String)
class SQLAlchemyMd5Cache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(
        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
    ):
        """Initialize by creating all tables."""
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        rows = self._search_rows(prompt, llm_string)
        if rows:
            return [loads(row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string."""
        with Session(self.engine) as session, session.begin():
            self._delete_previous(session, prompt, llm_string)
            prompt_md5 = self.get_md5(prompt)
            items = [
                self.cache_schema(
                    id=str(uuid.uuid1()),
                    prompt=prompt,
                    prompt_md5=prompt_md5,
                    llm=llm_string,
                    response=dumps(gen),
                    idx=i,
                )
                for i, gen in enumerate(return_val)
            ]
            for item in items:
                session.merge(item)

    def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
        stmt = (
            delete(self.cache_schema)
            .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
        )
        session.execute(stmt)

    def _search_rows(self, prompt: str, llm_string: str) -> Sequence[Row]:
        prompt_md5 = self.get_md5(prompt)
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            return session.execute(stmt).fetchall()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        with Session(self.engine) as session:
            session.execute(delete(self.cache_schema))
            session.commit()

    @staticmethod
    def get_md5(input_string: str) -> str:
        return hashlib.md5(input_string.encode()).hexdigest()
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
@deprecated(
    since="0.0.28",
    removal="0.3.0",
    alternative_import="langchain_astradb.AstraDBCache",
)
class AstraDBCache(BaseCache):
    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        return f"{_hash(prompt)}#{_hash(llm_string)}"
    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        pre_delete_collection: bool = False,
        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
    ):
        """
        Cache that uses Astra DB as a backend.

        It uses a single collection as a kv store.
        The lookup keys, combined in the _id of the documents, are:
            - prompt, a string
            - llm_string, a deterministic str representation of the model parameters.
              (needed to prevent same-prompt-different-model collisions)

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the collection is created.
                Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC,
                ASYNC or OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
        """
        self.astra_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )
        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        self.astra_env.ensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        item = self.collection.find_one(
            filter={
                "_id": doc_id,
            },
            projection={
                "body_blob": 1,
            },
        )["data"]["document"]
        return _loads_generations(item["body_blob"]) if item is not None else None

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        await self.astra_env.aensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        item = (
            await self.async_collection.find_one(
                filter={
                    "_id": doc_id,
                },
                projection={
                    "body_blob": 1,
                },
            )
        )["data"]["document"]
        return _loads_generations(item["body_blob"]) if item is not None else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self.astra_env.ensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        blob = _dumps_generations(return_val)
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": blob,
            },
        )

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        await self.astra_env.aensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        blob = _dumps_generations(return_val)
        await self.async_collection.upsert(
            {
                "_id": doc_id,
                "body_blob": blob,
            },
        )
    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `delete` with the LLM being passed.
        In case the llm.invoke(prompt) calls have a `stop` param, you should
        pass it here.
        """
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.delete(prompt, llm_string=llm_string)

    async def adelete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `adelete` with the LLM being passed.
        In case the llm.invoke(prompt) calls have a `stop` param, you should
        pass it here.
        """
        llm_string = (
            await aget_prompts(
                {**llm.dict(), **{"stop": stop}},
                [],
            )
        )[1]
        return await self.adelete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        self.astra_env.ensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        self.collection.delete_one(doc_id)

    async def adelete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        await self.astra_env.aensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        await self.async_collection.delete_one(doc_id)
    def clear(self, **kwargs: Any) -> None:
        self.astra_env.ensure_db_setup()
        self.collection.clear()

    async def aclear(self, **kwargs: Any) -> None:
        await self.astra_env.aensure_db_setup()
        await self.async_collection.clear()
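def _example_astra_db_cache() -> None:
    # Illustrative sketch (not part of the upstream module); the endpoint
    # and token below are placeholders for real Astra DB credentials.
    from langchain_core.globals import set_llm_cache

    set_llm_cache(
        AstraDBCache(
            api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
            token="<astra-token>",
        )
    )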
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16

_unset = ["unset"]


class _CachedAwaitable:
    """Caches the result of an awaitable so it can be awaited multiple times."""

    def __init__(self, awaitable: Awaitable[Any]):
        self.awaitable = awaitable
        self.result = _unset

    def __await__(self) -> Generator:
        if self.result is _unset:
            self.result = yield from self.awaitable.__await__()
        return self.result


def _reawaitable(func: Callable) -> Callable:
    """Makes an async function's result awaitable multiple times."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> _CachedAwaitable:
        return _CachedAwaitable(func(*args, **kwargs))

    return wrapper


def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
    """Least-recently-used async cache decorator.
    Equivalent to functools.lru_cache for async functions.
    """

    def decorating_function(user_function: Callable) -> Callable:
        return lru_cache(maxsize, typed)(_reawaitable(user_function))

    return decorating_function
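async def _example_async_lru_cache() -> None:
    # Illustrative sketch (not part of the upstream module): with
    # `_async_lru_cache`, repeated calls with the same argument return the
    # same `_CachedAwaitable`, so the wrapped coroutine body runs only once.
    @_async_lru_cache(maxsize=4)
    async def embed(text: str) -> int:
        print("computing:", text)
        return len(text)

    assert await embed("hello") == 5
    assert await embed("hello") == 5  # served from the cache; no recompute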
@deprecated(
    since="0.0.28",
    removal="0.3.0",
    alternative_import="langchain_astradb.AstraDBSemanticCache",
)
class AstraDBSemanticCache(BaseCache):
    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding: Embeddings,
        metric: Optional[str] = None,
        similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
    ):
        """
        Cache that uses Astra DB as a vector-store backend for semantic
        (i.e. similarity-based) lookup.

        It uses a single (vector) collection and can store
        cached values from several LLMs, so the LLM's 'llm_string' is stored
        in the document metadata.

        You can choose the preferred similarity (or use the API default).
        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the collection is created.
                Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC,
                ASYNC or OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
            embedding: Embedding provider for semantic encoding and search.
            metric: the function to use for evaluating similarity of text embeddings.
                Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product').
            similarity_threshold: the minimum similarity for accepting a
                (semantic-search) match.
        """
        self.embedding = embedding
        self.metric = metric
        self.similarity_threshold = similarity_threshold
        self.collection_name = collection_name

        # The contract for this class has separate lookup and update:
        # in order to spare some embedding calculations we cache them between
        # the two calls.
        # Note: each instance of this class has its own `_get_embedding` with
        # its own lru.
        @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = _cache_embedding

        @_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        async def _acache_embedding(text: str) -> List[float]:
            return await self.embedding.aembed_query(text=text)

        self._aget_embedding = _acache_embedding

        embedding_dimension: Union[int, Awaitable[int], None] = None
        if setup_mode == AstraSetupMode.ASYNC:
            embedding_dimension = self._aget_embedding_dimension()
        elif setup_mode == AstraSetupMode.SYNC:
            embedding_dimension = self._get_embedding_dimension()

        self.astra_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
            embedding_dimension=embedding_dimension,
            metric=metric,
        )
        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection
    def _get_embedding_dimension(self) -> int:
        return len(self._get_embedding(text="This is a sample sentence."))

    async def _aget_embedding_dimension(self) -> int:
        return len(await self._aget_embedding(text="This is a sample sentence."))

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        return f"{_hash(prompt)}#{_hash(llm_string)}"
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self.astra_env.ensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        llm_string_hash = _hash(llm_string)
        embedding_vector = self._get_embedding(text=prompt)
        body = _dumps_generations(return_val)
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": body,
                "llm_string_hash": llm_string_hash,
                "$vector": embedding_vector,
            }
        )

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        await self.astra_env.aensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        llm_string_hash = _hash(llm_string)
        embedding_vector = await self._aget_embedding(text=prompt)
        body = _dumps_generations(return_val)
        await self.async_collection.upsert(
            {
                "_id": doc_id,
                "body_blob": body,
                "llm_string_hash": llm_string_hash,
                "$vector": embedding_vector,
            }
        )
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        if hit_with_id is not None:
            return hit_with_id[1]
        else:
            return None

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        hit_with_id = await self.alookup_with_id(prompt, llm_string)
        if hit_with_id is not None:
            return hit_with_id[1]
        else:
            return None
    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry) for the top match.
        """
        self.astra_env.ensure_db_setup()
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = self.collection.vector_find_one(
            vector=prompt_embedding,
            filter={
                "llm_string_hash": llm_string_hash,
            },
            fields=["body_blob", "_id"],
            include_similarity=True,
        )

        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None
        else:
            generations = _loads_generations(hit["body_blob"])
            if generations is not None:
                # this protects against malformed cached items:
                return hit["_id"], generations
            else:
                return None

    async def alookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry) for the top match.
        """
        await self.astra_env.aensure_db_setup()
        prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = await self.async_collection.vector_find_one(
            vector=prompt_embedding,
            filter={
                "llm_string_hash": llm_string_hash,
            },
            fields=["body_blob", "_id"],
            include_similarity=True,
        )

        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None
        else:
            generations = _loads_generations(hit["body_blob"])
            if generations is not None:
                # this protects against malformed cached items:
                return hit["_id"], generations
            else:
                return None
    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    async def alookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        llm_string = (
            await aget_prompts(
                {**llm.dict(), **{"stop": stop}},
                [],
            )
        )[1]
        return await self.alookup_with_id(prompt, llm_string=llm_string)
    def delete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with the ID. This is for the second step.
        """
        self.astra_env.ensure_db_setup()
        self.collection.delete_one(document_id)

    async def adelete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with the ID. This is for the second step.
        """
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_one(document_id)
    def clear(self, **kwargs: Any) -> None:
        self.astra_env.ensure_db_setup()
        self.collection.clear()

    async def aclear(self, **kwargs: Any) -> None:
        await self.astra_env.aensure_db_setup()
        await self.async_collection.clear()
class AzureCosmosDBSemanticCache(BaseCache):
    """Cache that uses Cosmos DB Mongo vCore vector-store backend"""

    DEFAULT_DATABASE_NAME = "CosmosMongoVCoreCacheDB"
    DEFAULT_COLLECTION_NAME = "CosmosMongoVCoreCacheColl"
    def __init__(
        self,
        cosmosdb_connection_string: str,
        database_name: str,
        collection_name: str,
        embedding: Embeddings,
        *,
        cosmosdb_client: Optional[Any] = None,
        num_lists: int = 100,
        similarity: CosmosDBSimilarityType = CosmosDBSimilarityType.COS,
        kind: CosmosDBVectorSearchType = CosmosDBVectorSearchType.VECTOR_IVF,
        dimensions: int = 1536,
        m: int = 16,
        ef_construction: int = 64,
        ef_search: int = 40,
        score_threshold: Optional[float] = None,
        application_name: str = "LANGCHAIN_CACHING_PYTHON",
    ):
        """
        Args:
            cosmosdb_connection_string: Cosmos DB Mongo vCore connection string
            cosmosdb_client: Cosmos DB Mongo vCore client
            embedding (Embedding): Embedding provider for semantic encoding and search.
            database_name: Database name for the CosmosDBMongoVCoreSemanticCache
            collection_name: Collection name for the CosmosDBMongoVCoreSemanticCache
            num_lists: This integer is the number of clusters that the
                inverted file (IVF) index uses to group the vector data.
                We recommend that numLists is set to documentCount/1000
                for up to 1 million documents and to sqrt(documentCount)
                for more than 1 million documents.
                Using a numLists value of 1 is akin to performing
                brute-force search, which has limited performance.
            dimensions: Number of dimensions for vector similarity.
                The maximum number of supported dimensions is 2000.
            similarity: Similarity metric to use with the IVF index.
                Possible options are:
                    - CosmosDBSimilarityType.COS (cosine distance),
                    - CosmosDBSimilarityType.L2 (Euclidean distance), and
                    - CosmosDBSimilarityType.IP (inner product).
            kind: Type of vector index to create.
                Possible options are:
                    - vector-ivf
                    - vector-hnsw: available as a preview feature only; to
                      enable, visit https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/preview-features
            m: The max number of connections per layer (16 by default, minimum
                value is 2, maximum value is 100). Higher m is suitable for
                datasets with high dimensionality and/or high accuracy requirements.
            ef_construction: the size of the dynamic candidate list for
                constructing the graph (64 by default, minimum value is 4,
                maximum value is 1000). Higher ef_construction will result in
                better index quality and higher accuracy, but it will also
                increase the time required to build the index.
                ef_construction has to be at least 2 * m.
            ef_search: The size of the dynamic candidate list for search
                (40 by default). A higher value provides better recall at
                the cost of speed.
            score_threshold: Maximum score used to filter the vector search
                documents.
            application_name: Application name for the client for tracking
                and logging
        """
        self._validate_enum_value(similarity, CosmosDBSimilarityType)
        self._validate_enum_value(kind, CosmosDBVectorSearchType)

        if not cosmosdb_connection_string:
            raise ValueError("CosmosDB connection string cannot be empty.")

        self.cosmosdb_connection_string = cosmosdb_connection_string
        self.cosmosdb_client = cosmosdb_client
        self.embedding = embedding
        self.database_name = database_name or self.DEFAULT_DATABASE_NAME
        self.collection_name = collection_name or self.DEFAULT_COLLECTION_NAME
        self.num_lists = num_lists
        self.dimensions = dimensions
        self.similarity = similarity
        self.kind = kind
        self.m = m
        self.ef_construction = ef_construction
        self.ef_search = ef_search
        self.score_threshold = score_threshold
        self._cache_dict: Dict[str, AzureCosmosDBVectorSearch] = {}
        self.application_name = application_name
    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBVectorSearch:
        index_name = self._index_name(llm_string)

        namespace = self.database_name + "." + self.collection_name

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        if self.cosmosdb_client:
            collection = self.cosmosdb_client[self.database_name][self.collection_name]
            self._cache_dict[index_name] = AzureCosmosDBVectorSearch(
                collection=collection,
                embedding=self.embedding,
                index_name=index_name,
            )
        else:
            self._cache_dict[
                index_name
            ] = AzureCosmosDBVectorSearch.from_connection_string(
                connection_string=self.cosmosdb_connection_string,
                namespace=namespace,
                embedding=self.embedding,
                index_name=index_name,
                application_name=self.application_name,
            )

        # create index for the vectorstore
        vectorstore = self._cache_dict[index_name]
        if not vectorstore.index_exists():
            vectorstore.create_index(
                self.num_lists,
                self.dimensions,
                self.similarity,
                self.kind,
                self.m,
                self.ef_construction,
            )

        return vectorstore
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Read from a Hash
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            kind=self.kind,
            ef_search=self.ef_search,
            score_threshold=self.score_threshold,  # type: ignore[arg-type]
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "CosmosDBMongoVCoreSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].get_collection().delete_many({})
            # self._cache_dict[index_name].clear_collection()

    @staticmethod
    def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
        if not isinstance(value, enum_type):
            raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")
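def _example_cosmosdb_semantic_cache(embedding: Embeddings) -> None:
    # Illustrative sketch (not part of the upstream module); the connection
    # string below is a placeholder for a real Cosmos DB Mongo vCore instance.
    from langchain_core.globals import set_llm_cache

    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string="<mongo-vcore-connection-string>",
            database_name="CosmosMongoVCoreCacheDB",
            collection_name="CosmosMongoVCoreCacheColl",
            embedding=embedding,
        )
    )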
class OpenSearchSemanticCache(BaseCache):
    """Cache that uses OpenSearch vector store backend"""

    def __init__(
        self, opensearch_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """
        Args:
            opensearch_url (str): URL to connect to OpenSearch.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            score_threshold (float, 0.2):

        Example:

        .. code-block:: python

            import langchain
            from langchain.cache import OpenSearchSemanticCache
            from langchain.embeddings import OpenAIEmbeddings

            langchain.llm_cache = OpenSearchSemanticCache(
                opensearch_url="http://localhost:9200",
                embedding=OpenAIEmbeddings()
            )
        """
        self._cache_dict: Dict[str, OpenSearchVectorStore] = {}
        self.opensearch_url = opensearch_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache_{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> OpenSearchVectorStore:
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        self._cache_dict[index_name] = OpenSearchVectorStore(
            opensearch_url=self.opensearch_url,
            index_name=index_name,
            embedding_function=self.embedding,
        )

        # create index for the vectorstore
        vectorstore = self._cache_dict[index_name]
        if not vectorstore.index_exists():
            _embedding = self.embedding.embed_query(text="test")
            vectorstore.create_index(len(_embedding), index_name)
        return vectorstore
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Read from a Hash
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            score_threshold=self.score_threshold,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "OpenSearchSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].delete_index(index_name=index_name)
            del self._cache_dict[index_name]