Source code for langchain_community.llms.ctranslate2
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
class CTranslate2(BaseLLM):
"""CTranslate2语言模型。"""
model_path: str = ""
"""CTranslate2模型目录的路径。"""
tokenizer_name: str = ""
"""需要加载正确分词器的原始Hugging Face模型的名称。"""
device: str = "cpu"
"""要使用的设备(可能的取值有:cpu、cuda、auto)。"""
device_index: Union[int, List[int]] = 0
"""需要将此生成器放置在的设备ID。"""
compute_type: Union[str, Dict[str, str]] = "default"
"""模型计算类型或将设备名称映射到计算类型的字典
(可能的值包括:default,默认值,auto,int8,int8_float32,int8_float16,
int8_bfloat16,int16,float16,bfloat16,float32)。"""
max_length: int = 512
"""最大生成长度。"""
sampling_topk: int = 1
"""从前K个候选项中随机抽取预测结果。"""
sampling_topp: float = 1
"""保留最有可能的令牌,其累积概率超过此值。"""
sampling_temperature: float = 1
"""对温度进行抽样以生成更多的随机样本。"""
client: Any #: :meta private:
tokenizer: Any #: :meta private:
ctranslate2_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""保存所有适用于`ctranslate2.Generator`调用但未明确指定的模型参数。"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""验证Python包是否存在于环境中。"""
try:
import ctranslate2
except ImportError:
raise ImportError(
"Could not import ctranslate2 python package. "
"Please install it with `pip install ctranslate2`."
)
try:
import transformers
except ImportError:
raise ImportError(
"Could not import transformers python package. "
"Please install it with `pip install transformers`."
)
values["client"] = ctranslate2.Generator(
model_path=values["model_path"],
device=values["device"],
device_index=values["device_index"],
compute_type=values["compute_type"],
**values["ctranslate2_kwargs"],
)
values["tokenizer"] = transformers.AutoTokenizer.from_pretrained(
values["tokenizer_name"]
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""获取默认参数。"""
return {
"max_length": self.max_length,
"sampling_topk": self.sampling_topk,
"sampling_topp": self.sampling_topp,
"sampling_temperature": self.sampling_temperature,
}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
# build sampling parameters
params = {**self._default_params, **kwargs}
        # CTranslate2 consumes token strings rather than raw text or ids,
        # so encode each prompt and map the ids back to token strings
encoded_prompts = self.tokenizer(prompts)["input_ids"]
tokenized_prompts = [
self.tokenizer.convert_ids_to_tokens(encoded_prompt)
for encoded_prompt in encoded_prompts
]
results = self.client.generate_batch(tokenized_prompts, **params)
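        # keep only the best-scoring hypothesis per prompt and detokenize it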
sequences = [result.sequences_ids[0] for result in results]
decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]
generations = []
for text in decoded_sequences:
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "ctranslate2"