"""Source code for langchain_experimental.tabular_synthetic_data.base."""

import asyncio
from typing import Any, Dict, List, Optional, Union

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.few_shot import FewShotPromptTemplate


class SyntheticDataGenerator(BaseModel):
    """Generate synthetic data using the given LLM and few-shot template.

    Utilizes the provided LLM to produce synthetic data based on the
    few-shot prompt template.

    Attributes:
        template (FewShotPromptTemplate): Template for the few-shot prompt.
        llm (Optional[BaseLanguageModel]): Large language model to use for generation.
        llm_chain (Optional[Chain]): LLM chain combining the llm and the few-shot template.
        example_input_key (str): Key under which rotated examples are stored.

    Usage Example:
        >>> template = FewShotPromptTemplate(...)
        >>> llm = BaseLanguageModel(...)
        >>> generator = SyntheticDataGenerator(template=template, llm=llm)
        >>> results = generator.generate(subject="climate change", runs=5)
    """

    template: FewShotPromptTemplate
    llm: Optional[BaseLanguageModel] = None
    # Accumulates every generated example across calls to generate()/agenerate().
    results: list = []
    llm_chain: Optional[Chain] = None
    example_input_key: str = "example"

    class Config:
        # Re-run validators (incl. set_llm_chain) when fields are assigned later.
        validate_assignment = True

    @root_validator(pre=False, skip_on_failure=True)
    def set_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``llm_chain`` from ``llm`` and ``template`` when not supplied directly.

        Raises:
            ValueError: If neither ``llm_chain`` nor both ``llm`` and
                ``template`` are provided.
        """
        llm_chain = values.get("llm_chain")
        llm = values.get("llm")
        few_shot_template = values.get("template")

        if not llm_chain:  # If llm_chain is None or not present
            if llm is None or few_shot_template is None:
                raise ValueError(
                    "Both llm and few_shot_template must be provided if llm_chain is "
                    "not given."
                )
            values["llm_chain"] = LLMChain(llm=llm, prompt=few_shot_template)

        return values

    @staticmethod
    def _format_dict_to_string(input_dict: Dict) -> str:
        """Render a dict as a ``"key: value, key: value"`` string."""
        formatted_str = ", ".join(
            [f"{key}: {value}" for key, value in input_dict.items()]
        )
        return formatted_str

    def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
        """Prevent duplicates by adding previously generated examples to the few shot list."""
        if self.template and self.template.examples:
            if isinstance(example, BaseModel):
                formatted_example = self._format_dict_to_string(example.dict())
            elif isinstance(example, dict):
                formatted_example = self._format_dict_to_string(example)
            else:
                formatted_example = str(example)
            # Rotate: drop the oldest example so the prompt stays a fixed size.
            self.template.examples.pop(0)
            self.template.examples.append({self.example_input_key: formatted_example})
[docs] def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]: """使用给定的主题字符串生成合成数据。 参数: subject (str): 合成数据的主题。 runs (int): 生成数据的次数。 extra (str): 用于数据生成中指导的额外指令。 返回: List[str]: 生成的合成数据列表。 使用示例: >>> results = generator.generate(subject="climate change", runs=5, extra="Focus on environmental impacts.") """ if self.llm_chain is None: raise ValueError( "llm_chain is none, either set either llm_chain or llm at generator " "construction" ) for _ in range(runs): result = self.llm_chain.run(subject=subject, *args, **kwargs) self.results.append(result) self._update_examples(result) return self.results
[docs] async def agenerate( self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any ) -> List[str]: """使用给定的主题异步生成合成数据。 注意:由于LLM并发调用,通过向“extra”关键字参数添加特定指令,可以减少重复。 Args: subject (str): 将生成合成数据的主题。 runs (int): 异步生成数据的次数。 extra (str): 数据生成中的额外指令,用于调整数据生成的方向。 Returns: List[str]: 给定主题生成的合成数据列表。 Usage Example: >>> results = await generator.agenerate(subject="climate change", runs=5, extra="Focus on env impacts.") """ async def run_chain( subject: str, extra: str = "", *args: Any, **kwargs: Any ) -> None: if self.llm_chain is not None: result = await self.llm_chain.arun( subject=subject, extra=extra, *args, **kwargs ) self.results.append(result) await asyncio.gather( *(run_chain(subject=subject, extra=extra) for _ in range(runs)) ) return self.results