# Source code for langchain.memory.vectorstore

"""Class for a VectorStore-backed memory object."""

from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.vectorstores import VectorStoreRetriever

from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key


class VectorStoreRetrieverMemory(BaseMemory):
    """Memory backed by a ``VectorStoreRetriever``.

    Saving a turn writes one ``Document`` (built from the turn's inputs and
    outputs) to the retriever; loading queries the retriever with the value
    found under the prompt input key and returns the hits under
    ``memory_key``.
    """

    retriever: VectorStoreRetriever = Field(exclude=True)
    """The VectorStoreRetriever object to connect to."""

    memory_key: str = "history"  #: :meta private:
    """Key name under which memories appear in the load_memory_variables result."""

    input_key: Optional[str] = None
    """Key name of the input whose value is used to load the memory variables."""

    return_docs: bool = False
    """Whether to return the result of querying the database directly."""

    exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
    """Input keys to exclude, besides the memory key, when building the document."""

    @property
    def memory_variables(self) -> List[str]:
        """Return the list of keys emitted by ``load_memory_variables``."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Return the input key to use for the prompt.

        An explicitly configured ``input_key`` wins; otherwise the choice is
        delegated to ``get_prompt_input_key(inputs, self.memory_variables)``.
        """
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def _documents_to_memory_variables(
        self, docs: List[Document]
    ) -> Dict[str, Union[List[Document], str]]:
        """Wrap retrieved documents in a one-entry dict keyed by ``memory_key``.

        When ``return_docs`` is true the documents are passed through as-is;
        otherwise their ``page_content`` values are joined with newlines.
        """
        payload: Union[List[Document], str]
        if self.return_docs:
            payload = docs
        else:
            payload = "\n".join(doc.page_content for doc in docs)
        return {self.memory_key: payload}

    def load_memory_variables(
        self, inputs: Dict[str, Any]
    ) -> Dict[str, Union[List[Document], str]]:
        """Return the history buffer.

        Queries the retriever with ``inputs[<prompt input key>]``; raises
        ``KeyError`` if that key is absent from ``inputs``.
        """
        query = inputs[self._get_prompt_input_key(inputs)]
        retrieved = self.retriever.invoke(query)
        return self._documents_to_memory_variables(retrieved)

    async def aload_memory_variables(
        self, inputs: Dict[str, Any]
    ) -> Dict[str, Union[List[Document], str]]:
        """Return the history buffer (async counterpart of ``load_memory_variables``)."""
        query = inputs[self._get_prompt_input_key(inputs)]
        retrieved = await self.retriever.ainvoke(query)
        return self._documents_to_memory_variables(retrieved)

    def _form_documents(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> List[Document]:
        """Format the context of this conversation turn as documents.

        Returns a single-element list whose document contains one
        ``"key: value"`` line per retained input followed by one per output.
        """
        # Each document should only include the current turn, not the chat
        # history, so the memory key is always dropped from the inputs.
        skipped = {*self.exclude_input_keys, self.memory_key}
        lines = [f"{k}: {v}" for k, v in inputs.items() if k not in skipped]
        lines.extend(f"{k}: {v}" for k, v in outputs.items())
        return [Document(page_content="\n".join(lines))]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this conversation turn to the buffer."""
        self.retriever.add_documents(self._form_documents(inputs, outputs))

    async def asave_context(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> None:
        """Save the context of this conversation turn to the buffer (async)."""
        await self.retriever.aadd_documents(self._form_documents(inputs, outputs))

    def clear(self) -> None:
        """No-op: nothing to clear."""

    async def aclear(self) -> None:
        """No-op: nothing to clear."""