class OCIGenAIEmbeddings(BaseEmbedding):
    """OCI Generative AI embedding models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through ``auth_type`` and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.

    Make sure you have the required policies (profile/roles) to access the OCI
    Generative AI service. If a specific config profile is used, the name of the
    profile (from ~/.oci/config) must be passed through ``auth_profile``.

    To use, you must provide the compartment OCID, the endpoint URL and the
    model ID as named parameters to the constructor.

    Example:
        .. code-block:: python

            from llama_index.embeddings.oci_genai import OCIGenAIEmbeddings

            embeddings = OCIGenAIEmbeddings(
                model_name="MY_EMBEDDING_MODEL",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID",
            )
    """

    model_name: str = Field(
        description="ID or Name of the OCI Generative AI embedding model to use."
    )
    truncate: str = Field(
        description="Truncate embeddings that are too long from start or end, values START/ END/ NONE",
        default="END",
    )
    input_type: Optional[str] = Field(
        description="Model Input type. If not provided, search_document and search_query are used when needed.",
        default=None,
    )
    service_endpoint: str = Field(
        description="service endpoint url.",
        default=None,
    )
    compartment_id: str = Field(
        description="OCID of compartment.",
        default=None,
    )
    auth_type: Optional[str] = Field(
        description="Authentication type, can be: API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL. If not specified, API_KEY will be used",
        default="API_KEY",
    )
    auth_profile: Optional[str] = Field(
        description="The name of the profile in ~/.oci/config. If not specified , DEFAULT will be used",
        default="DEFAULT",
    )

    # OCI GenerativeAiInferenceClient, injected or built in __init__.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        truncate: str = "END",
        input_type: Optional[str] = None,
        service_endpoint: str = None,
        compartment_id: str = None,
        auth_type: Optional[str] = "API_KEY",
        auth_profile: Optional[str] = "DEFAULT",
        client: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Initialize the OCIGenAIEmbeddings class.

        Args:
            model_name (str): Name or ID of the model used to generate
                embeddings, e.g. "cohere.embed-english-light-v3.0".
            truncate (str): Truncation strategy for long input text. One of
                'START', 'END' or 'NONE'.
            input_type (Optional[str]): Optional input type passed to the
                model. Model dependent; may be one of "search_query",
                "search_document", "classification" or "clustering".
            service_endpoint (str): Service endpoint URL, e.g.
                "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com".
            compartment_id (str): OCID of the compartment.
            auth_type (Optional[str]): Authentication type, one of API_KEY
                (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL,
                RESOURCE_PRINCIPAL. API_KEY is used when unspecified.
            auth_profile (Optional[str]): Name of the profile in ~/.oci/config.
                DEFAULT is used when unspecified.
            client (Optional[Any]): Optional pre-built OCI client object. When
                not provided, a client is created from the given service
                endpoint and authentication method.
        """
        if client is not None:
            self._client = client
        else:
            try:
                import oci

                client_kwargs = {
                    "config": {},
                    "signer": None,
                    "service_endpoint": service_endpoint,
                    "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                    "timeout": (
                        10,
                        240,
                    ),  # default timeout config for OCI Gen AI service
                }

                if auth_type == OCIAuthType(1).name:  # API_KEY
                    client_kwargs["config"] = oci.config.from_file(
                        profile_name=auth_profile
                    )
                    client_kwargs.pop("signer", None)
                elif auth_type == OCIAuthType(2).name:  # SECURITY_TOKEN

                    def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                        # Build a signer from the profile's private key and
                        # the session token stored in security_token_file.
                        pk = oci.signer.load_private_key_from_file(
                            oci_config.get("key_file"), None
                        )
                        with open(
                            oci_config.get("security_token_file"), encoding="utf-8"
                        ) as f:
                            st_string = f.read()
                        return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                    client_kwargs["config"] = oci.config.from_file(
                        profile_name=auth_profile
                    )
                    client_kwargs["signer"] = make_security_token_signer(
                        oci_config=client_kwargs["config"]
                    )
                elif auth_type == OCIAuthType(3).name:  # INSTANCE_PRINCIPAL
                    client_kwargs[
                        "signer"
                    ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                elif auth_type == OCIAuthType(4).name:  # RESOURCE_PRINCIPAL
                    client_kwargs[
                        "signer"
                    ] = oci.auth.signers.get_resource_principals_signer()
                else:
                    raise ValueError(
                        f"Please provide valid value to auth_type, {auth_type} is not valid."
                    )

                self._client = oci.generative_ai_inference.GenerativeAiInferenceClient(
                    **client_kwargs
                )
            except ImportError as ex:
                raise ModuleNotFoundError(
                    "Could not import oci python package. "
                    "Please make sure you have the oci package installed."
                ) from ex
            except Exception as e:
                raise ValueError(
                    """Could not authenticate with OCI client. Please check if ~/.oci/config exists.
                    If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, please check the specified
                    auth_profile and auth_type are valid.""",
                    e,
                ) from e

        super().__init__(
            model_name=model_name,
            truncate=truncate,
            input_type=input_type,
            service_endpoint=service_endpoint,
            compartment_id=compartment_id,
            auth_type=auth_type,
            auth_profile=auth_profile,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        # Fixed: @classmethod first parameter renamed self -> cls (PEP 8).
        return "OCIGenAIEmbeddings"

    @staticmethod
    def list_supported_models() -> List[str]:
        """Return the names of the models supported by this integration."""
        return list(SUPPORTED_MODELS)

    def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
        """Embed a batch of texts via the OCI embed_text API.

        Args:
            texts (List[str]): Texts to embed.
            input_type (str): Input type used when ``self.input_type`` is unset.

        Returns:
            List[List[float]]: One embedding vector per input text.

        Raises:
            ModuleNotFoundError: If the ``oci`` package is not installed.
        """
        try:
            from oci.generative_ai_inference import models
        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex

        # Model names carrying the custom-endpoint prefix are treated as
        # dedicated serving endpoint OCIDs rather than on-demand model IDs.
        if self.model_name.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_name)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_name)

        request = models.EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            input_type=self.input_type or input_type,
            truncate=self.truncate,
            inputs=texts,
        )

        response = self._client.embed_text(request)
        return response.data.embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed([query], input_type="SEARCH_QUERY")[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document string."""
        return self._embed([text], input_type="SEARCH_DOCUMENT")[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document strings.

        Fixed: parameter was annotated ``text: str`` although a list of texts
        is received and forwarded to ``_embed`` (matching BaseEmbedding's
        batch API).
        """
        return self._embed(texts, input_type="SEARCH_DOCUMENT")

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async wrapper; delegates to the synchronous implementation."""
        return self._get_text_embedding(text)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async wrapper; delegates to the synchronous implementation."""
        return self._get_query_embedding(query)