from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_experimental.rl_chain.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from langchain_experimental.rl_chain.model_repository import ModelRepository
from langchain_experimental.rl_chain.vw_logger import VwLogger
if TYPE_CHECKING:
import vowpal_wabbit_next as vw
logger = logging.getLogger(__name__)
class _BasedOn:
def __init__(self, value: Any):
self.value = value
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
[docs]def BasedOn(anything: Any) -> _BasedOn:
"""将一个值包装起来,表示它应该是基于某种条件的。"""
return _BasedOn(anything)
class _ToSelectFrom:
def __init__(self, value: Any):
self.value = value
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
[docs]def ToSelectFrom(anything: Any) -> _ToSelectFrom:
"""将一个值包装起来,以指示它应该被选择。"""
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
class _Embed:
def __init__(self, value: Any, keep: bool = False):
self.value = value
self.keep = keep
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def Embed(anything: Any, keep: bool = False) -> Any:
"""将一个值包装起来,以指示它应该被嵌入。"""
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
return BasedOn(Embed(anything.value, keep=keep))
if isinstance(anything, list):
return [Embed(v, keep=keep) for v in anything]
elif isinstance(anything, dict):
return {k: Embed(v, keep=keep) for k, v in anything.items()}
elif isinstance(anything, _Embed):
return anything
return _Embed(anything, keep=keep)
[docs]def EmbedAndKeep(anything: Any) -> Any:
"""将一个值包装起来,以指示它应该被嵌入和保留。"""
return Embed(anything, keep=True)
# 辅助函数
[docs]def stringify_embedding(embedding: List) -> str:
"""将嵌入转换为字符串。"""
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
[docs]def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
"""将输入字符串解析为示例列表。"""
return [parser.parse_line(line) for line in input_str.split("\n")]
[docs]def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
"""从输入中获取BasedOn和ToSelectFrom。"""
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
}
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
)
based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else [inputs[k].value]
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
}
return based_on, to_select_from
# 结束辅助函数
[docs]class Selected(ABC):
"""用于表示所选项目的抽象类。"""
pass
TSelected = TypeVar("TSelected", bound=Selected)
[docs]class Event(Generic[TSelected], ABC):
"""表示事件的抽象类。"""
inputs: Dict[str, Any]
selected: Optional[TSelected]
[docs] def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
self.inputs = inputs
self.selected = selected
TEvent = TypeVar("TEvent", bound=Event)
[docs]class Policy(Generic[TEvent], ABC):
"""表示策略的抽象类。"""
[docs] def __init__(self, **kwargs: Any):
pass
[docs] @abstractmethod
def predict(self, event: TEvent) -> Any:
...
[docs] @abstractmethod
def learn(self, event: TEvent) -> None:
...
[docs] @abstractmethod
def log(self, event: TEvent) -> None:
...
[docs] def save(self) -> None:
pass
[docs]class VwPolicy(Policy):
"""Vowpal Wabbit策略。"""
[docs] def __init__(
self,
model_repo: ModelRepository,
vw_cmd: List[str],
feature_embedder: Embedder,
vw_logger: VwLogger,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.model_repo = model_repo
self.workspace = self.model_repo.load(vw_cmd)
self.feature_embedder = feature_embedder
self.vw_logger = vw_logger
[docs] def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(
parse_lines(text_parser, self.feature_embedder.format(event))
)
[docs] def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
[docs] def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event)
self.vw_logger.log(vw_ex)
[docs] def save(self) -> None:
self.model_repo.save(self.workspace)
[docs]class Embedder(Generic[TEvent], ABC):
"""用于表示嵌入器的抽象类。"""
[docs] def __init__(self, *args: Any, **kwargs: Any):
pass
[docs]class SelectionScorer(Generic[TEvent], ABC, BaseModel):
"""用于对所选选择或llm的响应进行评分的抽象类。"""
[docs] @abstractmethod
def score_response(
self, inputs: Dict[str, Any], llm_response: str, event: TEvent
) -> float:
...
[docs]class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
"""自动选择评分器。"""
llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
[docs] @staticmethod
def get_default_system_prompt() -> SystemMessagePromptTemplate:
return SystemMessagePromptTemplate.from_template(
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
You are a strict judge that is called on to rank a response based on \
given criteria. You must respond with your ranking by providing a \
single float within the range [0, 1], 0 being very bad \
response and 1 being very good response."
)
[docs] @staticmethod
def get_default_prompt() -> ChatPromptTemplate:
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
as the most important attribute, rank how good or bad this text is: \
"{rl_chain_selected}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
chat_prompt = ChatPromptTemplate.from_messages(
[default_system_prompt, human_message_prompt]
)
return chat_prompt
@root_validator(pre=True)
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm = values.get("llm")
prompt = values.get("prompt")
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
if prompt is None and scoring_criteria_template_str is None:
prompt = AutoSelectionScorer.get_default_prompt()
elif prompt is None and scoring_criteria_template_str is not None:
human_message_prompt = HumanMessagePromptTemplate.from_template(
scoring_criteria_template_str
)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
prompt = ChatPromptTemplate.from_messages(
[default_system_prompt, human_message_prompt]
)
values["prompt"] = prompt
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
return values
[docs] def score_response(
self, inputs: Dict[str, Any], llm_response: str, event: Event
) -> float:
ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
ranking = ranking.strip()
try:
resp = float(ranking)
return resp
except Exception as e:
raise RuntimeError(
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
)
[docs]class RLChain(Chain, Generic[TEvent]):
"""利用Vowpal Wabbit(VW)模型作为强化学习的学习策略的链条。
属性:
- llm_chain (Chain): 表示基础语言模型链条。
- prompt (BasePromptTemplate): 基础提示的模板。
- selection_scorer (Union[SelectionScorer, None]): 选择评分器。可以设置为None。
- policy (Optional[Policy]): 链条用于学习填充动态提示的策略。
- auto_embed (bool): 确定是否自动嵌入。默认为False。
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): 指标跟踪器,可以设置为None。
初始化属性:
- feature_embedder (Embedder): 用于“BasedOn”和“ToSelectFrom”输入的嵌入器。
- model_save_dir (str, optional): 保存VW模型的目录。默认为当前目录。
- reset_model (bool): 如果设置为True,则模型将从头开始训练。默认为False。
- vw_cmd (List[str], optional): VW模型的命令行参数。
- policy (Type[VwPolicy]): 链条使用的策略。
- vw_logs (Optional[Union[str, os.PathLike]]): VW日志的路径。
- metrics_step (int): 指标跟踪器的步长。默认为-1。如果设置而不设置metrics_window_size,则将跟踪平均指标,否则将跟踪滚动窗口指标。
- metrics_window_size (int): 指标跟踪器的窗口大小。默认为-1。如果设置,将跟踪滚动窗口指标。
注意:
该类使用提供的参数初始化VW模型。如果未提供`selection_scorer`,则会记录警告,指示除非调用`update_with_delayed_score`方法,否则不会发生强化学习。""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
class _NoOpPolicy(Policy):
"""什么也不做的占位符策略"""
def predict(self, event: TEvent) -> Any:
return None
def learn(self, event: TEvent) -> None:
pass
def log(self, event: TEvent) -> None:
pass
llm_chain: Chain
output_key: str = "result" #: :元数据 私有:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
active_policy: Policy = _NoOpPolicy()
auto_embed: bool = False
selection_scorer_activated: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None
def __init__(
self,
feature_embedder: Embedder,
model_save_dir: str = "./",
reset_model: bool = False,
vw_cmd: Optional[List[str]] = None,
policy: Type[Policy] = VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step: int = -1,
metrics_window_size: int = -1,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if self.selection_scorer is None:
logger.warning(
"No selection scorer provided, which means that no \
reinforcement learning will be done in the RL chain \
unless update_with_delayed_score is called."
)
if isinstance(self.active_policy, RLChain._NoOpPolicy):
self.active_policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd or [],
feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
if metrics_window_size > 0:
self.metrics = MetricsTrackerRollingWindow(
step=metrics_step, window_size=metrics_window_size
)
else:
self.metrics = MetricsTrackerAverage(step=metrics_step)
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""期望输入键。
:meta private:
"""
return []
@property
def output_keys(self) -> List[str]:
"""期望的输出键。
:meta private:
"""
return [self.output_key]
[docs] def update_with_delayed_score(
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
) -> None:
"""
使用提供的分数更新学习策略。
如果设置了selection_scorer,并且在方法调用时未提供force_score=True,将会引发错误。
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
if self._can_use_selection_scorer() and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
)
if self.metrics:
self.metrics.on_feedback(score)
event: TEvent = chain_response["selection_metadata"]
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
[docs] def deactivate_selection_scorer(self) -> None:
"""
停用选择评分器,这意味着链不再尝试使用选择评分器对响应进行评分。
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
self.selection_scorer_activated = False
[docs] def activate_selection_scorer(self) -> None:
"""
激活选择评分器,这意味着该链将尝试使用选择评分器对响应进行评分。
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
self.selection_scorer_activated = True
[docs] def save_progress(self) -> None:
"""
应该调用此函数来保存学习策略模型的状态。
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
self.active_policy.save()
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs)
if (
self.selected_input_key in inputs.keys()
or self.selected_based_on_input_key in inputs.keys()
):
raise ValueError(
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
)
def _can_use_selection_scorer(self) -> bool:
"""
返回链是否可以使用选择评分器来对响应进行评分。
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
return self.selection_scorer is not None and self.selection_scorer_activated
@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
...
@abstractmethod
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_llm_before_scoring(
self, llm_response: str, event: TEvent
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_scoring_before_learning(
self, event: TEvent, score: Optional[float]
) -> TEvent:
...
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.active_policy.predict(event=event)
if self.metrics:
self.metrics.on_decision()
next_chain_inputs, event = self._call_after_predict_before_llm(
inputs=inputs, event=event, prediction=prediction
)
t = self.llm_chain.run(**next_chain_inputs, callbacks=_run_manager.get_child())
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
if self.verbose:
_run_manager.on_text("\nCode: ", verbose=self.verbose)
output = t
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
next_chain_inputs, event = self._call_after_llm_before_scoring(
llm_response=output, event=event
)
score = None
try:
if self._can_use_selection_scorer():
score = self.selection_scorer.score_response( # 类型:忽略
inputs=next_chain_inputs, llm_response=output, event=event
)
except Exception as e:
logger.info(
f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}"
)
if self.metrics and score is not None:
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
return {self.output_key: {"response": output, "selection_metadata": event}}
@property
def _chain_type(self) -> str:
return "llm_personalizer_chain"
[docs]def is_stringtype_instance(item: Any) -> bool:
"""检查一个项目是否为字符串。"""
return isinstance(item, str) or (
isinstance(item, _Embed) and isinstance(item.value, str)
)
[docs]def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, Union[str, List[str]]]:
"""嵌入一个字符串或一个_Embed对象。"""
keep_str = ""
if isinstance(item, _Embed):
encoded = stringify_embedding(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_") + " "
elif isinstance(item, str):
encoded = item.replace(" ", "_")
else:
raise ValueError(f"Unsupported type {type(item)} for embedding")
if namespace is None:
raise ValueError(
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
)
return {namespace: keep_str + encoded}
[docs]def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""嵌入一个字典项。"""
inner_dict: Dict = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
for embed_list_item in embed_item:
embedded = embed_string_type(embed_list_item, model, ns)
inner_dict[ns].append(embedded[ns])
else:
inner_dict.update(embed_string_type(embed_item, model, ns))
return inner_dict
[docs]def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
"""嵌入一个列表项。"""
ret_list: List = []
for embed_item in item:
if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model))
elif isinstance(embed_item, list):
item_embedding = embed_list_type(embed_item, model, namespace)
# 从第一个字典中获取第一个键
first_key = next(iter(item_embedding[0]))
# 将该键下的数值进行分组
grouping = {first_key: [item[first_key] for item in item_embedding]}
ret_list.append(grouping)
else:
ret_list.append(embed_string_type(embed_item, model, namespace))
return ret_list
[docs]def embed(
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
"""
使用SentenceTransformer模型(或具有`encode`函数的模型)嵌入动作或上下文。
属性:
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) 要嵌入的文本,可以是字符串、字符串列表、字典或字典列表。
namespace: (str, optional) 在未提供字典或字典列表时要使用的默认命名空间。
model: (Any, required) 用于嵌入的模型
返回:
List[Dict[str, str]]: 一个字典列表,每个字典的键是命名空间,值是嵌入的字符串
""" # noqa: E501 表示忽略 PEP 8 中的 E501 错误,即行长度超过 79 个字符的错误。 表示忽略 PEP 8 中的行长度限制。 表示忽略PEP 8规范中的每行字符数限制。 表示忽略超过每行字符数限制的警告。 表示忽略PEP 8规范中的行长度限制。 不检查行长度限制 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制,允许当前行超过最大长度限制。 表示忽略 PEP 8 中的行长度限制。 表示忽略超过每行字符限制的警告。 表示忽略超过最大行长度的警告。
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
return [embed_string_type(to_embed, model, namespace)]
elif isinstance(to_embed, dict):
return [embed_dict_type(to_embed, model)]
elif isinstance(to_embed, list):
return embed_list_type(to_embed, model, namespace)
else:
raise ValueError("Invalid input format for embedding")