Source code for langchain_community.llms.oci_generative_ai

from __future__ import annotations

from abc import ABC
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
VALID_PROVIDERS = ("cohere", "meta")


class OCIAuthType(Enum):
    """OCI authentication types as an enumerator."""

    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4
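Note that ``validate_environment`` below compares the user-supplied ``auth_type`` string against the *names* of these members, not their integer values; a minimal illustration:

# Minimal illustration: the validator matches on enum member names,
# so auth_type must be spelled exactly like a member of OCIAuthType.
assert OCIAuthType(1).name == "API_KEY"
assert OCIAuthType(3).name == "INSTANCE_PRINCIPAL"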
class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models."""

    client: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be one of
    API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.
    If not specified, API_KEY will be used.
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config.
    If not specified, DEFAULT will be used.
    """

    model_id: str = None  # type: ignore[assignment]
    """Id of the model to call, e.g., cohere.command"""

    provider: str = None  # type: ignore[assignment]
    """Provider name of the model. Defaults to None;
    it will be derived from the model_id if possible,
    otherwise it requires user input.
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: str = None  # type: ignore[assignment]
    """service endpoint url"""

    compartment_id: str = None  # type: ignore[assignment]
    """OCID of the compartment"""

    is_stream: bool = False
    """Whether to stream back partial progress"""

    llm_stop_sequence_mapping: Mapping[str, str] = {
        "cohere": "stop_sequences",
        "meta": "stop",
    }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the OCI config and python package exist in the environment."""

        # Skip creating a new client if one was passed in the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:

                def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()
            else:
                raise ValueError("Please provide a valid value for auth_type")

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ImportError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check that the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self) -> str:
        if self.provider is not None:
            provider = self.provider
        else:
            provider = self.model_id.split(".")[0].lower()

        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider derived from model_id: {self.model_id} "
                "Please explicitly pass in the supported provider "
                "when using a custom endpoint"
            )
        return provider
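Since ``auth_type`` and ``auth_profile`` drive client construction in ``validate_environment``, here is a hedged configuration sketch (the compartment OCID and profile name are placeholders):

# Placeholder values throughout: use session-token auth with a named
# profile from ~/.oci/config instead of the default API-key auth.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
    auth_type="SECURITY_TOKEN",  # must match an OCIAuthType member name
    auth_profile="MY_PROFILE",   # profile name in ~/.oci/config
)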
class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service.
    If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through auth_profile.

    To use, you must provide the compartment id
    along with the endpoint url and model id
    as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                model_id="MY_MODEL_ID",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID",
            )
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci"

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        from oci.generative_ai_inference import models

        oci_llm_request_mapping = {
            "cohere": models.CohereLlmInferenceRequest,
            "meta": models.LlamaLlmInferenceRequest,
        }
        provider = self._get_provider()
        _model_kwargs = self.model_kwargs or {}
        if stop is not None:
            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=oci_llm_request_mapping[provider](**inference_params),
        )

        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        provider = self._get_provider()
        if provider == "cohere":
            text = response.data.inference_response.generated_texts[0].text
        elif provider == "meta":
            text = response.data.inference_response.choices[0].text
        else:
            raise ValueError(f"Invalid provider: {provider}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to the OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)

        return self._process_response(response, stop)
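A short usage sketch tying the pieces together (the compartment OCID is a placeholder; ``temperature`` and ``max_tokens`` are Cohere inference-request parameters passed through ``model_kwargs``, and ``stop`` is mapped to the provider's stop-sequence field by ``_prepare_invocation_object``):

# Hedged example with placeholder OCIDs: model_kwargs are merged into
# the provider-specific inference request, and `stop` becomes
# `stop_sequences` for the cohere provider.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
    model_kwargs={"temperature": 0.7, "max_tokens": 200},
)
print(llm.invoke("Tell me a joke.", stop=["\n\n"]))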