from __future__ import annotations
from abc import ABC
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_community.llms.utils import enforce_stop_tokens
# A model_id starting with this prefix is treated as the OCID of a dedicated
# endpoint (DedicatedServingMode) instead of an on-demand model name.
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"

# Providers for which request construction and response parsing are implemented.
VALID_PROVIDERS = (
    "cohere",
    "meta",
)
class OCIAuthType(Enum):
    """Enumeration of the supported OCI authentication methods.

    The model's ``auth_type`` string is matched against each member's ``.name``.
    """

    # Signer derived from the API-key profile in the OCI config file.
    API_KEY = 1
    # Session token read from the profile's ``security_token_file``.
    SECURITY_TOKEN = 2
    # Instance-principals security token signer.
    INSTANCE_PRINCIPAL = 3
    # Resource-principals signer.
    RESOURCE_PRINCIPAL = 4
class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models.

    Holds the connection settings shared by OCI Generative AI wrappers. During
    validation it builds an ``oci.generative_ai_inference.GenerativeAiInferenceClient``
    unless a ``client`` was passed to the constructor.
    """

    client: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, one of

    API_KEY,
    SECURITY_TOKEN,
    INSTANCE_PRINCIPAL,
    RESOURCE_PRINCIPAL

    If not specified, API_KEY is used."""

    auth_profile: Optional[str] = "DEFAULT"
    """Name of the profile in ~/.oci/config.
    If not specified, DEFAULT is used."""

    model_id: str = None  # type: ignore[assignment]
    """ID of the model to call, e.g. cohere.command"""

    provider: str = None  # type: ignore[assignment]
    """Provider name of the model. Defaults to None, in which case it is derived
    from the model ID; otherwise (e.g. for a custom endpoint) it must be given."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments passed to the model."""

    service_endpoint: str = None  # type: ignore[assignment]
    """Service endpoint URL."""

    compartment_id: str = None  # type: ignore[assignment]
    """OCID of the compartment."""

    is_stream: bool = False
    """Whether to stream back partial progress."""

    # Provider name -> request parameter name that carries stop sequences.
    llm_stop_sequence_mapping: Mapping[str, str] = {
        "cohere": "stop_sequences",
        "meta": "stop",
    }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the OCI configuration and build the OCI inference client.

        Client construction is skipped when ``client`` was passed in.

        Raises:
            ImportError: if the ``oci`` package is not installed.
            ValueError: if ``auth_type`` is not an ``OCIAuthType`` member name,
                or if building the signer/client fails.
        """
        # Skip creating a new client if one was passed in the constructor.
        if values.get("client") is not None:
            return values

        try:
            import oci
        except ImportError as ex:
            raise ImportError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex

        auth_type = values.get("auth_type")
        # Validate before the broad ``except`` below, so an unsupported value is
        # reported as such instead of as a generic authentication failure.
        if auth_type not in OCIAuthType.__members__:
            raise ValueError(
                "Please provide valid value to auth_type. "
                f"Expected one of {list(OCIAuthType.__members__)}, "
                f"got {auth_type!r}."
            )

        try:
            client_kwargs: Dict[str, Any] = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if auth_type == OCIAuthType.API_KEY.name:
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                # No explicit signer: the SDK signs with the API key from config.
                client_kwargs.pop("signer", None)
            elif auth_type == OCIAuthType.SECURITY_TOKEN.name:

                def make_security_token_signer(oci_config: Dict[str, Any]) -> Any:
                    """Build a SecurityTokenSigner from the profile's key/token files."""
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif auth_type == OCIAuthType.INSTANCE_PRINCIPAL.name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif auth_type == OCIAuthType.RESOURCE_PRINCIPAL.name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "Please check the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self) -> str:
        """Return the model provider, explicit or derived from ``model_id``.

        The derived provider is the lower-cased text before the first ``.`` in
        ``model_id`` (e.g. ``cohere.command`` -> ``cohere``).

        Raises:
            ValueError: if the explicit or derived provider is not one of
                ``VALID_PROVIDERS``, or if neither ``provider`` nor
                ``model_id`` is set.
        """
        if self.provider is not None:
            # Reject unknown explicit providers here rather than failing later
            # with a KeyError in the provider-keyed request/stop mappings.
            if self.provider not in VALID_PROVIDERS:
                raise ValueError(
                    f"Invalid provider: {self.provider}. "
                    f"Supported providers are: {', '.join(VALID_PROVIDERS)}"
                )
            return self.provider

        if self.model_id is None:
            raise ValueError(
                "Cannot derive the provider because model_id is not set. "
                "Please pass in model_id (and provider when using custom endpoint)"
            )

        provider = self.model_id.split(".")[0].lower()
        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider derived from model_id: {self.model_id} "
                "Please explicitly pass in the supported provider "
                "when using custom endpoint"
            )
        return provider
class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through ``auth_type`` and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to access the OCI
    Generative AI service. If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through ``auth_profile``.

    To use, you must provide the compartment id, the endpoint url and the model
    id as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                    model_id="MY_MODEL_ID",
                    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                    compartment_id="MY_OCID"
                )
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "oci"

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Any:
        """Build the ``GenerateTextDetails`` request for one generation call.

        Args:
            prompt: The prompt to send.
            stop: Optional stop sequences, sent under the provider-specific
                parameter name from ``llm_stop_sequence_mapping``.
            kwargs: Per-call inference parameters; these override ``model_kwargs``.

        Returns:
            An ``oci.generative_ai_inference.models.GenerateTextDetails`` instance.
        """
        from oci.generative_ai_inference import models

        oci_llm_request_mapping = {
            "cohere": models.CohereLlmInferenceRequest,
            "meta": models.LlamaLlmInferenceRequest,
        }
        provider = self._get_provider()

        # Copy so per-call stop sequences are never written back into
        # self.model_kwargs, where they would leak into every later call.
        _model_kwargs = dict(self.model_kwargs or {})
        if stop is not None:
            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop

        # Dedicated endpoints are addressed by their OCID; anything else is
        # treated as an on-demand model id.
        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=oci_llm_request_mapping[provider](**inference_params),
        )
        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        """Extract the generated text from an OCI ``generate_text`` response.

        The text is also truncated at the stop sequences, if any are given.

        Raises:
            ValueError: if the provider has no known response layout.
        """
        provider = self._get_provider()
        if provider == "cohere":
            text = response.data.inference_response.generated_texts[0].text
        elif provider == "meta":
            text = response.data.inference_response.choices[0].text
        else:
            raise ValueError(f"Invalid provider: {provider}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to the OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)
        return self._process_response(response, stop)