from pathlib import Path
from typing import Any, Dict, List, Optional

# NOTE: import paths below follow the llama-index package layout this class
# comes from; adjust them if your install differs.
from llama_index.core.base.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.embeddings.huggingface.utils import format_query, format_text
from optimum.intel.openvino import OVModelForFeatureExtraction
from transformers import AutoTokenizer


class OpenVINOEmbedding(BaseEmbedding):
    model_id_or_path: str = Field(description="Huggingface model id or local path.")
    max_length: int = Field(description="Maximum length of input.")
    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files.", default=None
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    def __init__(
        self,
        model_id_or_path: str = "BAAI/bge-m3",
        pooling: str = "cls",
        max_length: Optional[int] = None,
        normalize: bool = True,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        model_kwargs: Dict[str, Any] = {},
        device: Optional[str] = "auto",
    ):
        try:
            from huggingface_hub import HfApi
        except ImportError as e:
            raise ValueError(
                "Could not import huggingface_hub python package. "
                "Please install it with: "
                "`pip install -U huggingface_hub`."
            ) from e

        def require_model_export(
            model_id: str, revision: Any = None, subfolder: Any = None
        ) -> bool:
            """Return True if the model still needs exporting to OpenVINO IR."""
            model_dir = Path(model_id)
            if subfolder is not None:
                model_dir = model_dir / subfolder
            if model_dir.is_dir():
                # Local directory: export is needed unless both IR files exist.
                return (
                    not (model_dir / "openvino_model.xml").exists()
                    or not (model_dir / "openvino_model.bin").exists()
                )
            hf_api = HfApi()
            try:
                # Remote model: check whether the Hub repo already ships the IR.
                model_info = hf_api.model_info(model_id, revision=revision or "main")
                normalized_subfolder = (
                    None if subfolder is None else Path(subfolder).as_posix()
                )
                model_files = [
                    file.rfilename
                    for file in model_info.siblings
                    if normalized_subfolder is None
                    or file.rfilename.startswith(normalized_subfolder)
                ]
                ov_model_path = (
                    "openvino_model.xml"
                    if subfolder is None
                    else f"{normalized_subfolder}/openvino_model.xml"
                )
                return (
                    ov_model_path not in model_files
                    or ov_model_path.replace(".xml", ".bin") not in model_files
                )
            except Exception:
                return True

        if require_model_export(model_id_or_path):
            # Use remote model and export it to OpenVINO IR on the fly.
            model = model or OVModelForFeatureExtraction.from_pretrained(
                model_id_or_path, export=True, device=device, **model_kwargs
            )
        else:
            # Use an already-exported local (or Hub-hosted) OpenVINO model.
            model = model or OVModelForFeatureExtraction.from_pretrained(
                model_id_or_path, device=device, **model_kwargs
            )

        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_id_or_path)

        if max_length is None:
            try:
                max_length = int(model.config.max_position_embeddings)
            except Exception:
                raise ValueError(
                    "Unable to find max_length from model config. "
                    "Please provide max_length."
                )
            try:
                max_length = min(max_length, int(tokenizer.model_max_length))
            except Exception as exc:
                print(
                    f"An error occurred while retrieving tokenizer max length: {exc}"
                )

        if pooling not in ["cls", "mean"]:
            raise ValueError(f"Pooling {pooling} not supported.")

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager or CallbackManager([]),
            model_id_or_path=model_id_or_path,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )
        self._device = device
        self._model = model
        self._tokenizer = tokenizer

    @classmethod
    def class_name(cls) -> str:
        return "OpenVINOEmbedding"

    @staticmethod
    def create_and_save_openvino_model(
        model_name_or_path: str,
        output_path: str,
        export_kwargs: Optional[dict] = None,
    ) -> None:
        try:
            from optimum.intel.openvino import OVModelForFeatureExtraction
            from transformers import AutoTokenizer
            from optimum.exporters.openvino.convert import export_tokenizer
        except ImportError:
            raise ImportError(
                "OpenVINO Embedding requires transformers and optimum to be installed.\n"
                "Please install them with "
                "`pip install transformers optimum[openvino]`."
            )
        export_kwargs = export_kwargs or {}
        model = OVModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=True, compile=False, **export_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        export_tokenizer(tokenizer, output_path)
        print(
            f"Saved OpenVINO model to {output_path}. Use it with "
            f"`embed_model = OpenVINOEmbedding(model_id_or_path='{output_path}')`."
        )

    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
        """Mean pooling - take the attention mask into account for correct averaging."""
        import torch

        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def _cls_pooling(self, model_output: list) -> Any:
        """Use the CLS token as the pooling token."""
        return model_output[0][:, 0]

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed sentences."""
        # The compiled OpenVINO model may have a static sequence length; if so,
        # pad every batch to that length, otherwise pad dynamically.
        length = self._model.request.inputs[0].get_partial_shape()[1]
        if length.is_dynamic:
            encoded_input = self._tokenizer(
                sentences,
                padding=True,
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            )
        else:
            encoded_input = self._tokenizer(
                sentences,
                padding="max_length",
                max_length=length.get_length(),
                truncation=True,
                return_tensors="pt",
            )
        model_output = self._model(**encoded_input)

        if self.pooling == "cls":
            embeddings = self._cls_pooling(model_output)
        else:
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction)
            for text in texts
        ]
        return self._embed(texts)