import json
from typing import Any, Dict, List, Mapping, Optional

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import ConfigDict

from langchain_community.llms.utils import enforce_stop_tokens

# key: task
# value: key in the output dictionary
VALID_TASKS_DICT = {
    "translation": "translation_text",
    "summarization": "summary_text",
    "conversational": "generated_text",
    "text-generation": "generated_text",
    "text2text-generation": "generated_text",
}
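
# Illustrative example of the mapping above (the response shape is an
# assumption based on the Inference API's usual JSON, not guaranteed here):
# a "summarization" call returns something like ``[{"summary_text": "..."}]``,
# so the generated text is read out with the key
# ``VALID_TASKS_DICT["summarization"]``, i.e. "summary_text".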

@deprecated(
    "0.0.21",
    removal="1.0",
    alternative_import="langchain_huggingface.HuggingFaceEndpoint",
)
class HuggingFaceHub(LLM):
    """HuggingFaceHub models.

    ! This class is deprecated; use ``HuggingFaceEndpoint`` instead.

    To use, you should have the ``huggingface_hub`` python package installed,
    and the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your
    API token, or pass it as a named parameter to the constructor.

    Supports `text-generation`, `text2text-generation`, `conversational`,
    `translation`, and `summarization`.

    Example:
        .. code-block:: python

            from langchain_community.llms import HuggingFaceHub

            hf = HuggingFaceHub(
                repo_id="gpt2", huggingfacehub_api_token="my-api-key"
            )
    """

    client: Any = None  #: :meta private:
    repo_id: Optional[str] = None
    """Model name to use.
    If not provided, the default model for the chosen task will be used."""
    task: Optional[str] = None
    """Task to call the model with.
    Should be a task that returns `generated_text`, `summary_text`,
    or `translation_text`."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    huggingfacehub_api_token: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        try:
            from huggingface_hub import HfApi, InferenceClient

            repo_id = values["repo_id"]
            client = InferenceClient(
                model=repo_id,
                token=huggingfacehub_api_token,
            )
            if not values["task"]:
                if not repo_id:
                    raise ValueError(
                        "Must specify either `repo_id` or `task`, or both."
                    )
                # Use the recommended task for the chosen model
                model_info = HfApi(token=huggingfacehub_api_token).model_info(
                    repo_id=repo_id
                )
                values["task"] = model_info.pipeline_tag
            if values["task"] not in VALID_TASKS_DICT:
                raise ValueError(
                    f"Got invalid task {values['task']}, "
                    f"currently only {VALID_TASKS_DICT.keys()} are supported"
                )
            values["client"] = client
        except ImportError:
            raise ImportError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
        return values
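
    # Illustrative example of the task inference above (the concrete values
    # are assumptions, not guaranteed by the Hub): for ``repo_id="gpt2"`` the
    # Hub's model metadata reports ``pipeline_tag == "text-generation"``, so
    # the task resolves to "text-generation" and the output key (via
    # VALID_TASKS_DICT) to "generated_text".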

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"repo_id": self.repo_id, "task": self.task},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "huggingface_hub"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to HuggingFace Hub's inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = hf("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}
        parameters = {**_model_kwargs, **kwargs}

        response = self.client.post(
            json={"inputs": prompt, "parameters": parameters}, task=self.task
        )
        response = json.loads(response.decode())
        if "error" in response:
            raise ValueError(f"Error raised by inference API: {response['error']}")

        response_key = VALID_TASKS_DICT[self.task]  # type: ignore
        if isinstance(response, list):
            text = response[0][response_key]
        else:
            text = response[response_key]
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
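
# A minimal usage sketch, assuming ``huggingface_hub`` is installed and
# ``HUGGINGFACEHUB_API_TOKEN`` is set in the environment; the model id and
# generation parameters below are illustrative, not prescribed by this module.
if __name__ == "__main__":
    hf = HuggingFaceHub(
        repo_id="gpt2",
        task="text-generation",
        model_kwargs={"temperature": 0.7, "max_new_tokens": 64},
    )
    # ``invoke`` comes from the base ``LLM`` interface; ``stop`` is forwarded
    # to ``_call``, where it is enforced client-side via enforce_stop_tokens.
    print(hf.invoke("Tell me a joke.", stop=["\n\n"]))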