# Source code for langchain_community.memory.kg
from typing import Any, Dict, List, Type, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
try:
    from langchain.chains.llm import LLMChain
    from langchain.memory.chat_memory import BaseChatMemory
    from langchain.memory.prompt import (
        ENTITY_EXTRACTION_PROMPT,
        KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
    )
    from langchain.memory.utils import get_prompt_input_key

    class ConversationKGMemory(BaseChatMemory):
        """Knowledge graph conversation memory.

        Integrates with an external knowledge graph to store and retrieve
        information about knowledge triples extracted from the conversation.
        On ``save_context`` the LLM extracts triples from the latest input and
        adds them to ``kg``; on ``load_memory_variables`` the LLM extracts the
        entities mentioned in the new input and the knowledge stored in ``kg``
        about each of them is returned as context.
        """

        # Amount of chat history shown to the LLM during entity / triple
        # extraction: the last ``2 * k`` messages. ``k <= 0`` means none.
        k: int = 2
        human_prefix: str = "Human"
        ai_prefix: str = "AI"
        kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
        knowledge_extraction_prompt: BasePromptTemplate = (
            KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
        )
        entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
        llm: BaseLanguageModel
        # Message class wrapping each entity summary when ``return_messages``
        # is enabled.
        summary_message_cls: Type[BaseMessage] = SystemMessage
        memory_key: str = "history"  #: :meta private:

        def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            """Return stored knowledge about the entities in the current input.

            Args:
                inputs: Chain inputs; the prompt input key is resolved with
                    ``_get_prompt_input_key``.

            Returns:
                ``{memory_key: context}``. For every extracted entity that has
                knowledge in ``kg``, a line ``"On <entity>: <fact>. <fact>."``
                is produced. ``context`` is a list of ``summary_message_cls``
                messages when ``return_messages`` is set, otherwise the lines
                joined with newlines; it is ``[]`` / ``""`` when no entity
                has stored knowledge.
            """
            entities = self._get_current_entities(inputs)
            summary_strings = []
            for entity in entities:
                knowledge = self.kg.get_entity_knowledge(entity)
                if knowledge:
                    summary = f"On {entity}: {'. '.join(knowledge)}."
                    summary_strings.append(summary)
            context: Union[str, List]
            if not summary_strings:
                context = [] if self.return_messages else ""
            elif self.return_messages:
                context = [
                    self.summary_message_cls(content=text) for text in summary_strings
                ]
            else:
                context = "\n".join(summary_strings)
            return {self.memory_key: context}

        @property
        def memory_variables(self) -> List[str]:
            """Will always return list of memory variables.

            :meta private:
            """
            return [self.memory_key]

        def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
            """Get the input key for the prompt.

            Returns ``self.input_key`` when it is set; otherwise the key is
            inferred from ``inputs`` by ``get_prompt_input_key``, with this
            memory's variables passed along.
            """
            if self.input_key is None:
                return get_prompt_input_key(inputs, self.memory_variables)
            return self.input_key

        def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
            """Get the output key for the prompt.

            Returns ``self.output_key`` when it is set; otherwise the single
            key of ``outputs``.

            Raises:
                ValueError: if ``output_key`` is unset and ``outputs`` does not
                    have exactly one key.
            """
            if self.output_key is None:
                if len(outputs) != 1:
                    raise ValueError(f"One output key expected, got {outputs.keys()}")
                return list(outputs.keys())[0]
            return self.output_key

        def _get_recent_history(self) -> str:
            """Format the last ``2 * k`` chat messages as a prompt string.

            A non-positive ``k`` yields an empty history. The guard is needed
            because ``messages[-0:]`` would otherwise return the *entire*
            history, and a negative ``k`` would drop the oldest messages.
            """
            messages = (
                self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
            )
            return get_buffer_string(
                messages,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )

        def get_current_entities(self, input_string: str) -> List[str]:
            """Ask the LLM which entities are mentioned in ``input_string``.

            The entity-extraction prompt receives the recent chat history
            (bounded by ``k``) and the new input; the raw LLM output is parsed
            with ``get_entities``.
            """
            chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
            output = chain.predict(
                history=self._get_recent_history(),
                input=input_string,
            )
            return get_entities(output)

        def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
            """Get the entities in the current conversation turn."""
            prompt_input_key = self._get_prompt_input_key(inputs)
            return self.get_current_entities(inputs[prompt_input_key])

        def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
            """Ask the LLM for knowledge triples contained in ``input_string``.

            The knowledge-extraction prompt receives the recent chat history
            (bounded by ``k``) and the new input; the raw LLM output is parsed
            with ``parse_triples``.
            """
            chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
            # Only prompt variables are passed: every kwarg given to
            # ``predict`` is treated as a chain input, not as a chain option.
            output = chain.predict(
                history=self._get_recent_history(),
                input=input_string,
            )
            knowledge = parse_triples(output)
            return knowledge

        def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
            """Extract triples from the current input and add them to ``kg``."""
            prompt_input_key = self._get_prompt_input_key(inputs)
            knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
            for triple in knowledge:
                self.kg.add_triple(triple)

        def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
            """Save context from this conversation turn.

            The turn is first stored via the base class. Then triples are
            extracted from the input (not the output) and added to the
            knowledge graph. The just-saved turn is therefore part of the
            history shown to the extraction prompt.
            """
            super().save_context(inputs, outputs)
            self._get_and_update_kg(inputs)

        def clear(self) -> None:
            """Clear memory contents: the chat history and the knowledge graph."""
            super().clear()
            self.kg.clear()

except ImportError:
    # Placeholder object: keeps this module importable when the ``langchain``
    # imports above fail.
    class ConversationKGMemory:  # type: ignore[no-redef]
        pass