Auto Classes
In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the from_pretrained() method. AutoClasses are here to do this job for you so that you automatically retrieve the relevant model given the name/path to the pretrained weights/config/vocabulary.
Instantiating one of AutoConfig, AutoModel, and AutoTokenizer will directly create a class of the relevant architecture. For instance
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
will create a model that is an instance of BertModel.
There is one class of AutoModel for each task, and for each backend (PyTorch, TensorFlow, or Flax).
Extending the Auto Classes
Each of the auto classes has a method to be extended with your custom classes. For instance, if you have defined a custom model class NewModel, make sure you have a NewModelConfig, then you can add those to the auto classes like this:
from transformers import AutoConfig, AutoModel
AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)
You will then be able to use the auto classes like you would usually do!
If your NewModelConfig is a subclass of PretrainedConfig, make sure its model_type attribute is set to the same key you use when registering the config (here "new-model").
Likewise, if your NewModel is a subclass of PreTrainedModel, make sure its config_class attribute is set to the same class you use when registering the model (here NewModelConfig), as in the sketch below.
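A minimal, self-contained sketch putting these requirements together (NewModelConfig and NewModel are illustrative placeholders, not library classes; the forward logic is a stand-in):

import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class NewModelConfig(PretrainedConfig):
    # model_type must match the key passed to AutoConfig.register below.
    model_type = "new-model"

    def __init__(self, hidden_size=64, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

class NewModel(PreTrainedModel):
    # config_class must match the config class passed to AutoModel.register below.
    config_class = NewModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        return self.linear(hidden_states)

AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)

# The auto classes now resolve "new-model" to the custom classes.
config = AutoConfig.for_model("new-model")
model = AutoModel.from_config(config)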
AutoConfig
This is a generic configuration class that will be instantiated as one of the configuration classes of the library when created with the from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
< source >( pretrained_model_name_or_path **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
  - A path to a directory containing a configuration file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final configuration object. If True, then this function returns a Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (additional keyword arguments, optional) — The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the configuration classes of the library from a pretrained model configuration.
The configuration class to instantiate is selected based on the model_type property of the config object that is loaded, or when it is missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertConfig (ALBERT model)
- align — AlignConfig (ALIGN model)
- altclip — AltCLIPConfig (AltCLIP model)
- audio-spectrogram-transformer — ASTConfig (Audio Spectrogram Transformer model)
- autoformer — AutoformerConfig (Autoformer model)
- bark — BarkConfig (Bark model)
- bart — BartConfig (BART model)
- beit — BeitConfig (BEiT model)
- bert — BertConfig (BERT model)
- bert-generation — BertGenerationConfig (Bert Generation model)
- big_bird — BigBirdConfig (BigBird model)
- bigbird_pegasus — BigBirdPegasusConfig (BigBird-Pegasus model)
- biogpt — BioGptConfig (BioGpt model)
- bit — BitConfig (BiT model)
- blenderbot — BlenderbotConfig (Blenderbot model)
- blenderbot-small — BlenderbotSmallConfig (BlenderbotSmall model)
- blip — BlipConfig (BLIP model)
- blip-2 — Blip2Config (BLIP-2 model)
- bloom — BloomConfig (BLOOM model)
- bridgetower — BridgeTowerConfig (BridgeTower model)
- bros — BrosConfig (BROS model)
- camembert — CamembertConfig (CamemBERT model)
- canine — CanineConfig (CANINE model)
- chameleon — ChameleonConfig (Chameleon model)
- chinese_clip — ChineseCLIPConfig (Chinese-CLIP model)
- chinese_clip_vision_model — ChineseCLIPVisionConfig (ChineseCLIPVisionModel model)
- clap — ClapConfig (CLAP model)
- clip — CLIPConfig (CLIP model)
- clip_text_model — CLIPTextConfig (CLIPTextModel model)
- clip_vision_model — CLIPVisionConfig (CLIPVisionModel model)
- clipseg — CLIPSegConfig (CLIPSeg model)
- clvp — ClvpConfig (CLVP model)
- code_llama — LlamaConfig (CodeLlama model)
- codegen — CodeGenConfig (CodeGen model)
- cohere — CohereConfig (Cohere model)
- conditional_detr — ConditionalDetrConfig (Conditional DETR model)
- convbert — ConvBertConfig (ConvBERT model)
- convnext — ConvNextConfig (ConvNeXT model)
- convnextv2 — ConvNextV2Config (ConvNeXTV2 model)
- cpmant — CpmAntConfig (CPM-Ant model)
- ctrl — CTRLConfig (CTRL model)
- cvt — CvtConfig (CvT model)
- dac — DacConfig (DAC model)
- data2vec-audio — Data2VecAudioConfig (Data2VecAudio model)
- data2vec-text — Data2VecTextConfig (Data2VecText model)
- data2vec-vision — Data2VecVisionConfig (Data2VecVision model)
- dbrx — DbrxConfig (DBRX model)
- deberta — DebertaConfig (DeBERTa model)
- deberta-v2 — DebertaV2Config (DeBERTa-v2 model)
- decision_transformer — DecisionTransformerConfig (Decision Transformer model)
- deformable_detr — DeformableDetrConfig (Deformable DETR model)
- deit — DeiTConfig (DeiT model)
- depth_anything — DepthAnythingConfig (Depth Anything model)
- deta — DetaConfig (DETA model)
- detr — DetrConfig (DETR model)
- dinat — DinatConfig (DiNAT model)
- dinov2 — Dinov2Config (DINOv2 model)
- distilbert — DistilBertConfig (DistilBERT model)
- donut-swin — DonutSwinConfig (DonutSwin model)
- dpr — DPRConfig (DPR model)
- dpt — DPTConfig (DPT model)
- efficientformer — EfficientFormerConfig (EfficientFormer model)
- efficientnet — EfficientNetConfig (EfficientNet model)
- electra — ElectraConfig (ELECTRA model)
- encodec — EncodecConfig (EnCodec model)
- encoder-decoder — EncoderDecoderConfig (Encoder decoder model)
- ernie — ErnieConfig (ERNIE model)
- ernie_m — ErnieMConfig (ErnieM model)
- esm — EsmConfig (ESM model)
- falcon — FalconConfig (Falcon model)
- falcon_mamba — FalconMambaConfig (FalconMamba model)
- fastspeech2_conformer — FastSpeech2ConformerConfig (FastSpeech2Conformer model)
- flaubert — FlaubertConfig (FlauBERT model)
- flava — FlavaConfig (FLAVA model)
- fnet — FNetConfig (FNet model)
- focalnet — FocalNetConfig (FocalNet model)
- fsmt — FSMTConfig (FairSeq Machine-Translation model)
- funnel — FunnelConfig (Funnel Transformer model)
- fuyu — FuyuConfig (Fuyu model)
- gemma — GemmaConfig (Gemma model)
- gemma2 — Gemma2Config (Gemma2 model)
- git — GitConfig (GIT model)
- glm — GlmConfig (GLM model)
- glpn — GLPNConfig (GLPN model)
- gpt-sw3 — GPT2Config (GPT-Sw3 model)
- gpt2 — GPT2Config (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeConfig (GPTBigCode model)
- gpt_neo — GPTNeoConfig (GPT Neo model)
- gpt_neox — GPTNeoXConfig (GPT NeoX model)
- gpt_neox_japanese — GPTNeoXJapaneseConfig (GPT NeoX Japanese model)
- gptj — GPTJConfig (GPT-J model)
- gptsan-japanese — GPTSanJapaneseConfig (GPTSAN-japanese model)
- granite — GraniteConfig (Granite model)
- granitemoe — GraniteMoeConfig (GraniteMoeMoe model)
- graphormer — GraphormerConfig (Graphormer model)
- grounding-dino — GroundingDinoConfig (Grounding DINO model)
- groupvit — GroupViTConfig (GroupViT model)
- hiera — HieraConfig (Hiera model)
- hubert — HubertConfig (Hubert model)
- ibert — IBertConfig (I-BERT model)
- idefics — IdeficsConfig (IDEFICS model)
- idefics2 — Idefics2Config (Idefics2 model)
- idefics3 — Idefics3Config (Idefics3 model)
- ijepa — IJepaConfig (I-JEPA model)
- imagegpt — ImageGPTConfig (ImageGPT model)
- informer — InformerConfig (Informer model)
- instructblip — InstructBlipConfig (InstructBLIP model)
- instructblipvideo — InstructBlipVideoConfig (InstructBlipVideo model)
- jamba — JambaConfig (Jamba model)
- jetmoe — JetMoeConfig (JetMoe model)
- jukebox — JukeboxConfig (Jukebox model)
- kosmos-2 — Kosmos2Config (KOSMOS-2 model)
- layoutlm — LayoutLMConfig (LayoutLM model)
- layoutlmv2 — LayoutLMv2Config (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Config (LayoutLMv3 model)
- led — LEDConfig (LED model)
- levit — LevitConfig (LeViT model)
- lilt — LiltConfig (LiLT model)
- llama — LlamaConfig (LLaMA model)
- llava — LlavaConfig (LLaVa model)
- llava_next — LlavaNextConfig (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoConfig (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionConfig (LLaVA-Onevision model)
- longformer — LongformerConfig (Longformer model)
- longt5 — LongT5Config (LongT5 model)
- luke — LukeConfig (LUKE model)
- lxmert — LxmertConfig (LXMERT model)
- m2m_100 — M2M100Config (M2M100 model)
- mamba — MambaConfig (Mamba model)
- mamba2 — Mamba2Config (mamba2 model)
- marian — MarianConfig (Marian model)
- markuplm — MarkupLMConfig (MarkupLM model)
- mask2former — Mask2FormerConfig (Mask2Former model)
- maskformer — MaskFormerConfig (MaskFormer model)
- maskformer-swin — MaskFormerSwinConfig (MaskFormerSwin model)
- mbart — MBartConfig (mBART model)
- mctct — MCTCTConfig (M-CTC-T model)
- mega — MegaConfig (MEGA model)
- megatron-bert — MegatronBertConfig (Megatron-BERT model)
- mgp-str — MgpstrConfig (MGP-STR model)
- mimi — MimiConfig (Mimi model)
- mistral — MistralConfig (Mistral model)
- mixtral — MixtralConfig (Mixtral model)
- mllama — MllamaConfig (Mllama model)
- mobilebert — MobileBertConfig (MobileBERT model)
- mobilenet_v1 — MobileNetV1Config (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2Config (MobileNetV2 model)
- mobilevit — MobileViTConfig (MobileViT model)
- mobilevitv2 — MobileViTV2Config (MobileViTV2 model)
- moshi — MoshiConfig (Moshi model)
- mpnet — MPNetConfig (MPNet model)
- mpt — MptConfig (MPT model)
- mra — MraConfig (MRA model)
- mt5 — MT5Config (MT5 model)
- musicgen — MusicgenConfig (MusicGen model)
- musicgen_melody — MusicgenMelodyConfig (MusicGen Melody model)
- mvp — MvpConfig (MVP model)
- nat — NatConfig (NAT model)
- nemotron — NemotronConfig (Nemotron model)
- nezha — NezhaConfig (Nezha model)
- nllb-moe — NllbMoeConfig (NLLB-MOE model)
- nougat — VisionEncoderDecoderConfig (Nougat model)
- nystromformer — NystromformerConfig (Nyströmformer model)
- olmo — OlmoConfig (OLMo model)
- olmo2 — Olmo2Config (OLMo2 model)
- olmoe — OlmoeConfig (OLMoE model)
- omdet-turbo — OmDetTurboConfig (OmDet-Turbo model)
- oneformer — OneFormerConfig (OneFormer model)
- open-llama — OpenLlamaConfig (OpenLlama model)
- openai-gpt — OpenAIGPTConfig (OpenAI GPT model)
- opt — OPTConfig (OPT model)
- owlv2 — Owlv2Config (OWLv2 model)
- owlvit — OwlViTConfig (OWL-ViT model)
- paligemma — PaliGemmaConfig (PaliGemma model)
- patchtsmixer — PatchTSMixerConfig (PatchTSMixer model)
- patchtst — PatchTSTConfig (PatchTST model)
- pegasus — PegasusConfig (Pegasus model)
- pegasus_x — PegasusXConfig (PEGASUS-X model)
- perceiver — PerceiverConfig (Perceiver model)
- persimmon — PersimmonConfig (Persimmon model)
- phi — PhiConfig (Phi model)
- phi3 — Phi3Config (Phi3 model)
- phimoe — PhimoeConfig (Phimoe model)
- pix2struct — Pix2StructConfig (Pix2Struct model)
- pixtral — PixtralVisionConfig (Pixtral model)
- plbart — PLBartConfig (PLBart model)
- poolformer — PoolFormerConfig (PoolFormer model)
- pop2piano — Pop2PianoConfig (Pop2Piano model)
- prophetnet — ProphetNetConfig (ProphetNet model)
- pvt — PvtConfig (PVT model)
- pvt_v2 — PvtV2Config (PVTv2 model)
- qdqbert — QDQBertConfig (QDQBert model)
- qwen2 — Qwen2Config (Qwen2 model)
- qwen2_audio — Qwen2AudioConfig (Qwen2Audio model)
- qwen2_audio_encoder — Qwen2AudioEncoderConfig (Qwen2AudioEncoder model)
- qwen2_moe — Qwen2MoeConfig (Qwen2MoE model)
- qwen2_vl — Qwen2VLConfig (Qwen2VL model)
- rag — RagConfig (RAG model)
- realm — RealmConfig (REALM model)
- recurrent_gemma — RecurrentGemmaConfig (RecurrentGemma model)
- reformer — ReformerConfig (Reformer model)
- regnet — RegNetConfig (RegNet model)
- rembert — RemBertConfig (RemBERT model)
- resnet — ResNetConfig (ResNet model)
- retribert — RetriBertConfig (RetriBERT model)
- roberta — RobertaConfig (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormConfig (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertConfig (RoCBert model)
- roformer — RoFormerConfig (RoFormer model)
- rt_detr — RTDetrConfig (RT-DETR model)
- rt_detr_resnet — RTDetrResNetConfig (RT-DETR-ResNet model)
- rwkv — RwkvConfig (RWKV model)
- sam — SamConfig (SAM model)
- seamless_m4t — SeamlessM4TConfig (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2Config (SeamlessM4Tv2 model)
- segformer — SegformerConfig (SegFormer model)
- seggpt — SegGptConfig (SegGPT model)
- sew — SEWConfig (SEW model)
- sew-d — SEWDConfig (SEW-D model)
- siglip — SiglipConfig (SigLIP model)
- siglip_vision_model — SiglipVisionConfig (SiglipVisionModel model)
- speech-encoder-decoder — SpeechEncoderDecoderConfig (Speech Encoder decoder model)
- speech_to_text — Speech2TextConfig (Speech2Text model)
- speech_to_text_2 — Speech2Text2Config (Speech2Text2 model)
- speecht5 — SpeechT5Config (SpeechT5 model)
- splinter — SplinterConfig (Splinter model)
- squeezebert — SqueezeBertConfig (SqueezeBERT model)
- stablelm — StableLmConfig (StableLm model)
- starcoder2 — Starcoder2Config (Starcoder2 model)
- superpoint — SuperPointConfig (SuperPoint model)
- swiftformer — SwiftFormerConfig (SwiftFormer model)
- swin — SwinConfig (Swin Transformer model)
- swin2sr — Swin2SRConfig (Swin2SR model)
- swinv2 — Swinv2Config (Swin Transformer V2 model)
- switch_transformers — SwitchTransformersConfig (SwitchTransformers model)
- t5 — T5Config (T5 model)
- table-transformer — TableTransformerConfig (Table Transformer model)
- tapas — TapasConfig (TAPAS model)
- time_series_transformer — TimeSeriesTransformerConfig (Time Series Transformer model)
- timesformer — TimesformerConfig (TimeSformer model)
- timm_backbone — TimmBackboneConfig (TimmBackbone model)
- trajectory_transformer — TrajectoryTransformerConfig (Trajectory Transformer model)
- transfo-xl — TransfoXLConfig (Transformer-XL model)
- trocr — TrOCRConfig (TrOCR model)
- tvlt — TvltConfig (TVLT model)
- tvp — TvpConfig (TVP model)
- udop — UdopConfig (UDOP model)
- umt5 — UMT5Config (UMT5 model)
- unispeech — UniSpeechConfig (UniSpeech model)
- unispeech-sat — UniSpeechSatConfig (UniSpeechSat model)
- univnet — UnivNetConfig (UnivNet model)
- upernet — UperNetConfig (UPerNet model)
- van — VanConfig (VAN model)
- video_llava — VideoLlavaConfig (VideoLlava model)
- videomae — VideoMAEConfig (VideoMAE model)
- vilt — ViltConfig (ViLT model)
- vipllava — VipLlavaConfig (VipLlava model)
- vision-encoder-decoder — VisionEncoderDecoderConfig (Vision Encoder decoder model)
- vision-text-dual-encoder — VisionTextDualEncoderConfig (VisionTextDualEncoder model)
- visual_bert — VisualBertConfig (VisualBERT model)
- vit — ViTConfig (ViT model)
- vit_hybrid — ViTHybridConfig (ViT Hybrid model)
- vit_mae — ViTMAEConfig (ViTMAE model)
- vit_msn — ViTMSNConfig (ViTMSN model)
- vitdet — VitDetConfig (VitDet model)
- vitmatte — VitMatteConfig (ViTMatte model)
- vits — VitsConfig (VITS model)
- vivit — VivitConfig (ViViT model)
- wav2vec2 — Wav2Vec2Config (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertConfig (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerConfig (Wav2Vec2-Conformer model)
- wavlm — WavLMConfig (WavLM model)
- whisper — WhisperConfig (Whisper model)
- xclip — XCLIPConfig (X-CLIP model)
- xglm — XGLMConfig (XGLM model)
- xlm — XLMConfig (XLM model)
- xlm-prophetnet — XLMProphetNetConfig (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaConfig (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLConfig (XLM-RoBERTa-XL model)
- xlnet — XLNetConfig (XLNet model)
- xmod — XmodConfig (X-MOD model)
- yolos — YolosConfig (YOLOS model)
- yoso — YosoConfig (YOSO model)
- zamba — ZambaConfig (Zamba model)
- zoedepth — ZoeDepthConfig (ZoeDepth model)
Examples:
>>> from transformers import AutoConfig
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
>>> # Load a specific configuration file.
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
>>> # Change some config attributes when loading a pretrained config.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
>>> config.output_attentions
True
>>> config, unused_kwargs = AutoConfig.from_pretrained(
... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
... )
>>> config.output_attentions
True
>>> unused_kwargs
{'foo': False}
register
< source >( model_type config exist_ok = False )
Parameters
- model_type (str) — The model type like "bert" or "gpt".
- config (PretrainedConfig) — The config to register.
Register a new configuration for this class.
AutoTokenizer
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when created with the AutoTokenizer.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
< source >( pretrained_model_name_or_path *inputs **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a predefined tokenizer hosted inside a model repo on huggingface.co.
  - A path to a directory containing vocabulary files required by the tokenizer, for instance saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (like Bert or XLNet), e.g.: ./my_model_directory/vocab.txt. (Not applicable to all derived classes)
- inputs (additional positional arguments, optional) — Will be passed along to the Tokenizer __init__() method.
- config (PretrainedConfig, optional) — The configuration object used to determine the tokenizer class to instantiate.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- subfolder (str, optional) — In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for facebook/rag-token-base), specify it here.
- use_fast (bool, optional, defaults to True) — Use a fast Rust-based tokenizer if it is supported for a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
- tokenizer_type (str, optional) — Tokenizer type to be loaded.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (additional keyword arguments, optional) — Will be passed to the Tokenizer __init__() method. Can be used to set special tokens like bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens. See parameters in the __init__() for more details.
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
The tokenizer class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it is missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertTokenizer or AlbertTokenizerFast (ALBERT model)
- align — BertTokenizer or BertTokenizerFast (ALIGN model)
- bark — BertTokenizer or BertTokenizerFast (Bark model)
- bart — BartTokenizer or BartTokenizerFast (BART model)
- barthez — BarthezTokenizer or BarthezTokenizerFast (BARThez model)
- bartpho — BartphoTokenizer (BARTpho model)
- bert — BertTokenizer or BertTokenizerFast (BERT model)
- bert-generation — BertGenerationTokenizer (Bert Generation model)
- bert-japanese — BertJapaneseTokenizer (BertJapanese model)
- bertweet — BertweetTokenizer (BERTweet model)
- big_bird — BigBirdTokenizer or BigBirdTokenizerFast (BigBird model)
- bigbird_pegasus — PegasusTokenizer or PegasusTokenizerFast (BigBird-Pegasus model)
- biogpt — BioGptTokenizer (BioGpt model)
- blenderbot — BlenderbotTokenizer or BlenderbotTokenizerFast (Blenderbot model)
- blenderbot-small — BlenderbotSmallTokenizer (BlenderbotSmall model)
- blip — BertTokenizer or BertTokenizerFast (BLIP model)
- blip-2 — GPT2Tokenizer or GPT2TokenizerFast (BLIP-2 model)
- bloom — BloomTokenizerFast (BLOOM model)
- bridgetower — RobertaTokenizer or RobertaTokenizerFast (BridgeTower model)
- bros — BertTokenizer or BertTokenizerFast (BROS model)
- byt5 — ByT5Tokenizer (ByT5 model)
- camembert — CamembertTokenizer or CamembertTokenizerFast (CamemBERT model)
- canine — CanineTokenizer (CANINE model)
- chameleon — LlamaTokenizer or LlamaTokenizerFast (Chameleon model)
- chinese_clip — BertTokenizer or BertTokenizerFast (Chinese-CLIP model)
- clap — RobertaTokenizer or RobertaTokenizerFast (CLAP model)
- clip — CLIPTokenizer or CLIPTokenizerFast (CLIP model)
- clipseg — CLIPTokenizer or CLIPTokenizerFast (CLIPSeg model)
- clvp — ClvpTokenizer (CLVP model)
- code_llama — CodeLlamaTokenizer or CodeLlamaTokenizerFast (CodeLlama model)
- codegen — CodeGenTokenizer or CodeGenTokenizerFast (CodeGen model)
- cohere — CohereTokenizerFast (Cohere model)
- convbert — ConvBertTokenizer or ConvBertTokenizerFast (ConvBERT model)
- cpm — CpmTokenizer or CpmTokenizerFast (CPM model)
- cpmant — CpmAntTokenizer (CPM-Ant model)
- ctrl — CTRLTokenizer (CTRL model)
- data2vec-audio — Wav2Vec2CTCTokenizer (Data2VecAudio model)
- data2vec-text — RobertaTokenizer or RobertaTokenizerFast (Data2VecText model)
- dbrx — GPT2Tokenizer or GPT2TokenizerFast (DBRX model)
- deberta — DebertaTokenizer or DebertaTokenizerFast (DeBERTa model)
- deberta-v2 — DebertaV2Tokenizer or DebertaV2TokenizerFast (DeBERTa-v2 model)
- distilbert — DistilBertTokenizer or DistilBertTokenizerFast (DistilBERT model)
- dpr — DPRQuestionEncoderTokenizer or DPRQuestionEncoderTokenizerFast (DPR model)
- electra — ElectraTokenizer or ElectraTokenizerFast (ELECTRA model)
- ernie — BertTokenizer or BertTokenizerFast (ERNIE model)
- ernie_m — ErnieMTokenizer (ErnieM model)
- esm — EsmTokenizer (ESM model)
- falcon — PreTrainedTokenizerFast (Falcon model)
- falcon_mamba — GPTNeoXTokenizerFast (FalconMamba model)
- fastspeech2_conformer — FastSpeech2ConformerTokenizer (FastSpeech2Conformer model)
- flaubert — FlaubertTokenizer (FlauBERT model)
- fnet — FNetTokenizer or FNetTokenizerFast (FNet model)
- fsmt — FSMTTokenizer (FairSeq Machine-Translation model)
- funnel — FunnelTokenizer or FunnelTokenizerFast (Funnel Transformer model)
- gemma — GemmaTokenizer or GemmaTokenizerFast (Gemma model)
- gemma2 — GemmaTokenizer or GemmaTokenizerFast (Gemma2 model)
- git — BertTokenizer or BertTokenizerFast (GIT model)
- glm — PreTrainedTokenizerFast (GLM model)
- gpt-sw3 — GPTSw3Tokenizer (GPT-Sw3 model)
- gpt2 — GPT2Tokenizer or GPT2TokenizerFast (OpenAI GPT-2 model)
- gpt_bigcode — GPT2Tokenizer or GPT2TokenizerFast (GPTBigCode model)
- gpt_neo — GPT2Tokenizer or GPT2TokenizerFast (GPT Neo model)
- gpt_neox — GPTNeoXTokenizerFast (GPT NeoX model)
- gpt_neox_japanese — GPTNeoXJapaneseTokenizer (GPT NeoX Japanese model)
- gptj — GPT2Tokenizer or GPT2TokenizerFast (GPT-J model)
- gptsan-japanese — GPTSanJapaneseTokenizer (GPTSAN-japanese model)
- grounding-dino — BertTokenizer or BertTokenizerFast (Grounding DINO model)
- groupvit — CLIPTokenizer or CLIPTokenizerFast (GroupViT model)
- herbert — HerbertTokenizer or HerbertTokenizerFast (HerBERT model)
- hubert — Wav2Vec2CTCTokenizer (Hubert model)
- ibert — RobertaTokenizer or RobertaTokenizerFast (I-BERT model)
- idefics — LlamaTokenizerFast (IDEFICS model)
- idefics2 — LlamaTokenizer or LlamaTokenizerFast (Idefics2 model)
- idefics3 — LlamaTokenizer or LlamaTokenizerFast (Idefics3 model)
- instructblip — GPT2Tokenizer or GPT2TokenizerFast (InstructBLIP model)
- instructblipvideo — GPT2Tokenizer or GPT2TokenizerFast (InstructBlipVideo model)
- jamba — LlamaTokenizer or LlamaTokenizerFast (Jamba model)
- jetmoe — LlamaTokenizer or LlamaTokenizerFast (JetMoe model)
- jukebox — JukeboxTokenizer (Jukebox model)
- kosmos-2 — XLMRobertaTokenizer or XLMRobertaTokenizerFast (KOSMOS-2 model)
- layoutlm — LayoutLMTokenizer or LayoutLMTokenizerFast (LayoutLM model)
- layoutlmv2 — LayoutLMv2Tokenizer or LayoutLMv2TokenizerFast (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Tokenizer or LayoutLMv3TokenizerFast (LayoutLMv3 model)
- layoutxlm — LayoutXLMTokenizer or LayoutXLMTokenizerFast (LayoutXLM model)
- led — LEDTokenizer or LEDTokenizerFast (LED model)
- lilt — LayoutLMv3Tokenizer or LayoutLMv3TokenizerFast (LiLT model)
- llama — LlamaTokenizer or LlamaTokenizerFast (LLaMA model)
- llava — LlamaTokenizer or LlamaTokenizerFast (LLaVa model)
- llava_next — LlamaTokenizer or LlamaTokenizerFast (LLaVA-NeXT model)
- llava_next_video — LlamaTokenizer or LlamaTokenizerFast (LLaVa-NeXT-Video model)
- llava_onevision — LlamaTokenizer or LlamaTokenizerFast (LLaVA-Onevision model)
- longformer — LongformerTokenizer or LongformerTokenizerFast (Longformer model)
- longt5 — T5Tokenizer or T5TokenizerFast (LongT5 model)
- luke — LukeTokenizer (LUKE model)
- lxmert — LxmertTokenizer or LxmertTokenizerFast (LXMERT model)
- m2m_100 — M2M100Tokenizer (M2M100 model)
- mamba — GPTNeoXTokenizerFast (Mamba model)
- mamba2 — GPTNeoXTokenizerFast (mamba2 model)
- marian — MarianTokenizer (Marian model)
- mbart — MBartTokenizer or MBartTokenizerFast (mBART model)
- mbart50 — MBart50Tokenizer or MBart50TokenizerFast (mBART-50 model)
- mega — RobertaTokenizer or RobertaTokenizerFast (MEGA model)
- megatron-bert — BertTokenizer or BertTokenizerFast (Megatron-BERT model)
- mgp-str — MgpstrTokenizer (MGP-STR model)
- mistral — LlamaTokenizer or LlamaTokenizerFast (Mistral model)
- mixtral — LlamaTokenizer or LlamaTokenizerFast (Mixtral model)
- mllama — LlamaTokenizer or LlamaTokenizerFast (Mllama model)
- mluke — MLukeTokenizer (mLUKE model)
- mobilebert — MobileBertTokenizer or MobileBertTokenizerFast (MobileBERT model)
- moshi — PreTrainedTokenizerFast (Moshi model)
- mpnet — MPNetTokenizer or MPNetTokenizerFast (MPNet model)
- mpt — GPTNeoXTokenizerFast (MPT model)
- mra — RobertaTokenizer or RobertaTokenizerFast (MRA model)
- mt5 — MT5Tokenizer or MT5TokenizerFast (MT5 model)
- musicgen — T5Tokenizer or T5TokenizerFast (MusicGen model)
- musicgen_melody — T5Tokenizer or T5TokenizerFast (MusicGen Melody model)
- mvp — MvpTokenizer or MvpTokenizerFast (MVP model)
- myt5 — MyT5Tokenizer (myt5 model)
- nezha — BertTokenizer or BertTokenizerFast (Nezha model)
- nllb — NllbTokenizer or NllbTokenizerFast (NLLB model)
- nllb-moe — NllbTokenizer or NllbTokenizerFast (NLLB-MOE model)
- nystromformer — AlbertTokenizer or AlbertTokenizerFast (Nyströmformer model)
- olmo — GPTNeoXTokenizerFast (OLMo model)
- olmo2 — GPTNeoXTokenizerFast (OLMo2 model)
- olmoe — GPTNeoXTokenizerFast (OLMoE model)
- omdet-turbo — CLIPTokenizer or CLIPTokenizerFast (OmDet-Turbo model)
- oneformer — CLIPTokenizer or CLIPTokenizerFast (OneFormer model)
- openai-gpt — OpenAIGPTTokenizer or OpenAIGPTTokenizerFast (OpenAI GPT model)
- opt — GPT2Tokenizer or GPT2TokenizerFast (OPT model)
- owlv2 — CLIPTokenizer or CLIPTokenizerFast (OWLv2 model)
- owlvit — CLIPTokenizer or CLIPTokenizerFast (OWL-ViT model)
- paligemma — LlamaTokenizer or LlamaTokenizerFast (PaliGemma model)
- pegasus — PegasusTokenizer or PegasusTokenizerFast (Pegasus model)
- pegasus_x — PegasusTokenizer or PegasusTokenizerFast (PEGASUS-X model)
- perceiver — PerceiverTokenizer (Perceiver model)
- persimmon — LlamaTokenizer or LlamaTokenizerFast (Persimmon model)
- phi — CodeGenTokenizer or CodeGenTokenizerFast (Phi model)
- phi3 — LlamaTokenizer or LlamaTokenizerFast (Phi3 model)
- phimoe — LlamaTokenizer or LlamaTokenizerFast (Phimoe model)
- phobert — PhobertTokenizer (PhoBERT model)
- pix2struct — T5Tokenizer or T5TokenizerFast (Pix2Struct model)
- pixtral — PreTrainedTokenizerFast (Pixtral model)
- plbart — PLBartTokenizer (PLBart model)
- prophetnet — ProphetNetTokenizer (ProphetNet model)
- qdqbert — BertTokenizer or BertTokenizerFast (QDQBert model)
- qwen2 — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2 model)
- qwen2_audio — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2Audio model)
- qwen2_moe — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2MoE model)
- qwen2_vl — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2VL model)
- rag — RagTokenizer (RAG model)
- realm — RealmTokenizer or RealmTokenizerFast (REALM model)
- recurrent_gemma — GemmaTokenizer or GemmaTokenizerFast (RecurrentGemma model)
- reformer — ReformerTokenizer or ReformerTokenizerFast (Reformer model)
- rembert — RemBertTokenizer or RemBertTokenizerFast (RemBERT model)
- retribert — RetriBertTokenizer or RetriBertTokenizerFast (RetriBERT model)
- roberta — RobertaTokenizer or RobertaTokenizerFast (RoBERTa model)
- roberta-prelayernorm — RobertaTokenizer or RobertaTokenizerFast (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertTokenizer (RoCBert model)
- roformer — RoFormerTokenizer or RoFormerTokenizerFast (RoFormer model)
- rwkv — GPTNeoXTokenizerFast (RWKV model)
- seamless_m4t — SeamlessM4TTokenizer or SeamlessM4TTokenizerFast (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4TTokenizer or SeamlessM4TTokenizerFast (SeamlessM4Tv2 model)
- siglip — SiglipTokenizer (SigLIP model)
- speech_to_text — Speech2TextTokenizer (Speech2Text model)
- speech_to_text_2 — Speech2Text2Tokenizer (Speech2Text2 model)
- speecht5 — SpeechT5Tokenizer (SpeechT5 model)
- splinter — SplinterTokenizer or SplinterTokenizerFast (Splinter model)
- squeezebert — SqueezeBertTokenizer or SqueezeBertTokenizerFast (SqueezeBERT model)
- stablelm — GPTNeoXTokenizerFast (StableLm model)
- starcoder2 — GPT2Tokenizer or GPT2TokenizerFast (Starcoder2 model)
- switch_transformers — T5Tokenizer or T5TokenizerFast (SwitchTransformers model)
- t5 — T5Tokenizer or T5TokenizerFast (T5 model)
- tapas — TapasTokenizer (TAPAS model)
- tapex — TapexTokenizer (TAPEX model)
- transfo-xl — TransfoXLTokenizer (Transformer-XL model)
- tvp — BertTokenizer or BertTokenizerFast (TVP model)
- udop — UdopTokenizer or UdopTokenizerFast (UDOP model)
- umt5 — T5Tokenizer or T5TokenizerFast (UMT5 model)
- video_llava — LlamaTokenizer or LlamaTokenizerFast (VideoLlava model)
- vilt — BertTokenizer or BertTokenizerFast (ViLT model)
- vipllava — LlamaTokenizer or LlamaTokenizerFast (VipLlava model)
- visual_bert — BertTokenizer or BertTokenizerFast (VisualBERT model)
- vits — VitsTokenizer (VITS model)
- wav2vec2 — Wav2Vec2CTCTokenizer (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2CTCTokenizer (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2CTCTokenizer (Wav2Vec2-Conformer model)
- wav2vec2_phoneme — Wav2Vec2PhonemeCTCTokenizer (Wav2Vec2Phoneme model)
- whisper — WhisperTokenizer or WhisperTokenizerFast (Whisper model)
- xclip — CLIPTokenizer or CLIPTokenizerFast (X-CLIP model)
- xglm — XGLMTokenizer or XGLMTokenizerFast (XGLM model)
- xlm — XLMTokenizer (XLM model)
- xlm-prophetnet — XLMProphetNetTokenizer (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa-XL model)
- xlnet — XLNetTokenizer or XLNetTokenizerFast (XLNet model)
- xmod — XLMRobertaTokenizer or XLMRobertaTokenizerFast (X-MOD model)
- yoso — AlbertTokenizer or AlbertTokenizerFast (YOSO model)
- zamba — LlamaTokenizer or LlamaTokenizerFast (Zamba model)
Examples:
>>> from transformers import AutoTokenizer
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
>>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
>>> # Download vocabulary from huggingface.co and define model-specific arguments
>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
register
< source >( config_class slow_tokenizer_class = None fast_tokenizer_class = None exist_ok = False )
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- slow_tokenizer_class (PretrainedTokenizer, optional) — The slow tokenizer to register.
- fast_tokenizer_class (PretrainedTokenizerFast, optional) — The fast tokenizer to register.
Register a new tokenizer in this mapping.
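For example, a custom slow/fast tokenizer pair could be registered like this (a minimal sketch; NewModelTokenizer and NewModelTokenizerFast are hypothetical subclasses of PreTrainedTokenizer and PreTrainedTokenizerFast, and NewModelConfig is the config registered in the extension section above):

from transformers import AutoTokenizer

AutoTokenizer.register(
    NewModelConfig,
    slow_tokenizer_class=NewModelTokenizer,
    fast_tokenizer_class=NewModelTokenizerFast,
)

# Checkpoints whose config has model_type "new-model" now resolve to these classes
# (the path below is a hypothetical local checkpoint).
tokenizer = AutoTokenizer.from_pretrained("path/to/new-model-checkpoint")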
AutoFeatureExtractor
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the library when created with the AutoFeatureExtractor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
< source >( pretrained_model_name_or_path **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained feature extractor hosted inside a model repo on huggingface.co.
  - a path to a directory containing a feature extractor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - a path or url to a saved feature extractor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model feature extractor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force to (re-)download the feature extractor files and override the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final feature extractor object. If True, then this function returns a Tuple(feature_extractor, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of kwargs which has not been used to update feature_extractor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not feature extractor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
The feature extractor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it is missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- audio-spectrogram-transformer — ASTFeatureExtractor (Audio Spectrogram Transformer model)
- beit — BeitFeatureExtractor (BEiT model)
- chinese_clip — ChineseCLIPFeatureExtractor (Chinese-CLIP model)
- clap — ClapFeatureExtractor (CLAP model)
- clip — CLIPFeatureExtractor (CLIP model)
- clipseg — ViTFeatureExtractor (CLIPSeg model)
- clvp — ClvpFeatureExtractor (CLVP model)
- conditional_detr — ConditionalDetrFeatureExtractor (Conditional DETR model)
- convnext — ConvNextFeatureExtractor (ConvNeXT model)
- cvt — ConvNextFeatureExtractor (CvT model)
- dac — DacFeatureExtractor (DAC model)
- data2vec-audio — Wav2Vec2FeatureExtractor (Data2VecAudio model)
- data2vec-vision — BeitFeatureExtractor (Data2VecVision model)
- deformable_detr — DeformableDetrFeatureExtractor (Deformable DETR model)
- deit — DeiTFeatureExtractor (DeiT model)
- detr — DetrFeatureExtractor (DETR model)
- dinat — ViTFeatureExtractor (DiNAT model)
- donut-swin — DonutFeatureExtractor (DonutSwin model)
- dpt — DPTFeatureExtractor (DPT model)
- encodec — EncodecFeatureExtractor (EnCodec model)
- flava — FlavaFeatureExtractor (FLAVA model)
- glpn — GLPNFeatureExtractor (GLPN model)
- groupvit — CLIPFeatureExtractor (GroupViT model)
- hubert — Wav2Vec2FeatureExtractor (Hubert model)
- imagegpt — ImageGPTFeatureExtractor (ImageGPT model)
- layoutlmv2 — LayoutLMv2FeatureExtractor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3FeatureExtractor (LayoutLMv3 model)
- levit — LevitFeatureExtractor (LeViT model)
- maskformer — MaskFormerFeatureExtractor (MaskFormer model)
- mctct — MCTCTFeatureExtractor (M-CTC-T model)
- mimi — EncodecFeatureExtractor (Mimi model)
- mobilenet_v1 — MobileNetV1FeatureExtractor (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2FeatureExtractor (MobileNetV2 model)
- mobilevit — MobileViTFeatureExtractor (MobileViT model)
- moshi — EncodecFeatureExtractor (Moshi model)
- nat — ViTFeatureExtractor (NAT model)
- owlvit — OwlViTFeatureExtractor (OWL-ViT model)
- perceiver — PerceiverFeatureExtractor (Perceiver model)
- poolformer — PoolFormerFeatureExtractor (PoolFormer model)
- pop2piano — Pop2PianoFeatureExtractor (Pop2Piano model)
- regnet — ConvNextFeatureExtractor (RegNet model)
- resnet — ConvNextFeatureExtractor (ResNet model)
- seamless_m4t — SeamlessM4TFeatureExtractor (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4TFeatureExtractor (SeamlessM4Tv2 model)
- segformer — SegformerFeatureExtractor (SegFormer model)
- sew — Wav2Vec2FeatureExtractor (SEW model)
- sew-d — Wav2Vec2FeatureExtractor (SEW-D model)
- speech_to_text — Speech2TextFeatureExtractor (Speech2Text model)
- speecht5 — SpeechT5FeatureExtractor (SpeechT5 model)
- swiftformer — ViTFeatureExtractor (SwiftFormer model)
- swin — ViTFeatureExtractor (Swin Transformer model)
- swinv2 — ViTFeatureExtractor (Swin Transformer V2 model)
- table-transformer — DetrFeatureExtractor (Table Transformer model)
- timesformer — VideoMAEFeatureExtractor (TimeSformer model)
- tvlt — TvltFeatureExtractor (TVLT model)
- unispeech — Wav2Vec2FeatureExtractor (UniSpeech model)
- unispeech-sat — Wav2Vec2FeatureExtractor (UniSpeechSat model)
- univnet — UnivNetFeatureExtractor (UnivNet model)
- van — ConvNextFeatureExtractor (VAN model)
- videomae — VideoMAEFeatureExtractor (VideoMAE model)
- vilt — ViltFeatureExtractor (ViLT model)
- vit — ViTFeatureExtractor (ViT model)
- vit_mae — ViTFeatureExtractor (ViTMAE model)
- vit_msn — ViTFeatureExtractor (ViTMSN model)
- wav2vec2 — Wav2Vec2FeatureExtractor (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2FeatureExtractor (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2FeatureExtractor (Wav2Vec2-Conformer model)
- wavlm — Wav2Vec2FeatureExtractor (WavLM model)
- whisper — WhisperFeatureExtractor (Whisper model)
- xclip — CLIPFeatureExtractor (X-CLIP model)
- yolos — YolosFeatureExtractor (YOLOS model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoFeatureExtractor
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
register
< source >( config_class feature_extractor_class exist_ok = False )
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- feature_extractor_class (FeatureExtractorMixin) — The feature extractor to register.
Register a new feature extractor for this class.
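Usage mirrors the other auto classes (a minimal sketch; NewModelFeatureExtractor is a hypothetical subclass of SequenceFeatureExtractor, and NewModelConfig is the config registered in the extension section above):

from transformers import AutoFeatureExtractor

AutoFeatureExtractor.register(NewModelConfig, NewModelFeatureExtractor)
# Checkpoints whose config has model_type "new-model" now resolve to it.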
AutoImageProcessor
This is a generic image processor class that will be instantiated as one of the image processor classes of the library when created with the AutoImageProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
< source >( pretrained_model_name_or_path *inputs **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained image processor hosted inside a model repo on huggingface.co.
  - a path to a directory containing an image processor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - a path or url to a saved image processor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model image processor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force to (re-)download the image processor files and override the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- use_fast (bool, optional, defaults to False) — Use a fast torchvision-based image processor if it is supported for a given model. If a fast image processor is not available for a given model, a normal numpy-based image processor is returned instead.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final image processor object. If True, then this function returns a Tuple(image_processor, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of kwargs which has not been used to update image_processor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are image processor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not image processor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
The image processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it is missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- align — EfficientNetImageProcessor (ALIGN model)
- beit — BeitImageProcessor (BEiT model)
- bit — BitImageProcessor (BiT model)
- blip — BlipImageProcessor (BLIP model)
- blip-2 — BlipImageProcessor (BLIP-2 model)
- bridgetower — BridgeTowerImageProcessor (BridgeTower model)
- chameleon — ChameleonImageProcessor (Chameleon model)
- chinese_clip — ChineseCLIPImageProcessor (Chinese-CLIP model)
- clip — CLIPImageProcessor (CLIP model)
- clipseg — ViTImageProcessor or ViTImageProcessorFast (CLIPSeg model)
- conditional_detr — ConditionalDetrImageProcessor (Conditional DETR model)
- convnext — ConvNextImageProcessor (ConvNeXT model)
- convnextv2 — ConvNextImageProcessor (ConvNeXTV2 model)
- cvt — ConvNextImageProcessor (CvT model)
- data2vec-vision — BeitImageProcessor (Data2VecVision model)
- deformable_detr — DeformableDetrImageProcessor or DeformableDetrImageProcessorFast (Deformable DETR model)
- deit — DeiTImageProcessor (DeiT model)
- depth_anything — DPTImageProcessor (Depth Anything model)
- deta — DetaImageProcessor (DETA model)
- detr — DetrImageProcessor or DetrImageProcessorFast (DETR model)
- dinat — ViTImageProcessor or ViTImageProcessorFast (DiNAT model)
- dinov2 — BitImageProcessor (DINOv2 model)
- donut-swin — DonutImageProcessor (DonutSwin model)
- dpt — DPTImageProcessor (DPT model)
- efficientformer — EfficientFormerImageProcessor (EfficientFormer model)
- efficientnet — EfficientNetImageProcessor (EfficientNet model)
- flava — FlavaImageProcessor (FLAVA model)
- focalnet — BitImageProcessor (FocalNet model)
- fuyu — FuyuImageProcessor (Fuyu model)
- git — CLIPImageProcessor (GIT model)
- glpn — GLPNImageProcessor (GLPN model)
- grounding-dino — GroundingDinoImageProcessor (Grounding DINO model)
- groupvit — CLIPImageProcessor (GroupViT model)
- hiera — BitImageProcessor (Hiera model)
- idefics — IdeficsImageProcessor (IDEFICS model)
- idefics2 — Idefics2ImageProcessor (Idefics2 model)
- idefics3 — Idefics3ImageProcessor (Idefics3 model)
- ijepa — ViTImageProcessor or ViTImageProcessorFast (I-JEPA model)
- imagegpt — ImageGPTImageProcessor (ImageGPT model)
- instructblip — BlipImageProcessor (InstructBLIP model)
- instructblipvideo — InstructBlipVideoImageProcessor (InstructBlipVideo model)
- kosmos-2 — CLIPImageProcessor (KOSMOS-2 model)
- layoutlmv2 — LayoutLMv2ImageProcessor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ImageProcessor (LayoutLMv3 model)
- levit — LevitImageProcessor (LeViT model)
- llava — CLIPImageProcessor (LLaVa model)
- llava_next — LlavaNextImageProcessor (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoImageProcessor (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionImageProcessor (LLaVA-Onevision model)
- mask2former — Mask2FormerImageProcessor (Mask2Former model)
- maskformer — MaskFormerImageProcessor (MaskFormer model)
- mgp-str — ViTImageProcessor or ViTImageProcessorFast (MGP-STR model)
- mllama — MllamaImageProcessor (Mllama model)
- mobilenet_v1 — MobileNetV1ImageProcessor (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2ImageProcessor (MobileNetV2 model)
- mobilevit — MobileViTImageProcessor (MobileViT model)
- mobilevitv2 — MobileViTImageProcessor (MobileViTV2 model)
- nat — ViTImageProcessor or ViTImageProcessorFast (NAT model)
- nougat — NougatImageProcessor (Nougat model)
- oneformer — OneFormerImageProcessor (OneFormer model)
- owlv2 — Owlv2ImageProcessor (OWLv2 model)
- owlvit — OwlViTImageProcessor (OWL-ViT model)
- paligemma — SiglipImageProcessor (PaliGemma model)
- perceiver — PerceiverImageProcessor (Perceiver model)
- pix2struct — Pix2StructImageProcessor (Pix2Struct model)
- pixtral — PixtralImageProcessor or PixtralImageProcessorFast (Pixtral model)
- poolformer — PoolFormerImageProcessor (PoolFormer model)
- pvt — PvtImageProcessor (PVT model)
- pvt_v2 — PvtImageProcessor (PVTv2 model)
- qwen2_vl — Qwen2VLImageProcessor (Qwen2VL model)
- regnet — ConvNextImageProcessor (RegNet model)
- resnet — ConvNextImageProcessor (ResNet model)
- rt_detr — RTDetrImageProcessor or RTDetrImageProcessorFast (RT-DETR model)
- sam — SamImageProcessor (SAM model)
- segformer — SegformerImageProcessor (SegFormer model)
- seggpt — SegGptImageProcessor (SegGPT model)
- siglip — SiglipImageProcessor (SigLIP model)
- swiftformer — ViTImageProcessor or ViTImageProcessorFast (SwiftFormer model)
- swin — ViTImageProcessor or ViTImageProcessorFast (Swin Transformer model)
- swin2sr — Swin2SRImageProcessor (Swin2SR model)
- swinv2 — ViTImageProcessor or ViTImageProcessorFast (Swin Transformer V2 model)
- table-transformer — DetrImageProcessor (Table Transformer model)
- timesformer — VideoMAEImageProcessor (TimeSformer model)
- tvlt — TvltImageProcessor (TVLT model)
- tvp — TvpImageProcessor (TVP model)
- udop — LayoutLMv3ImageProcessor (UDOP model)
- upernet — SegformerImageProcessor (UPerNet model)
- van — ConvNextImageProcessor (VAN model)
- videomae — VideoMAEImageProcessor (VideoMAE model)
- vilt — ViltImageProcessor (ViLT model)
- vipllava — CLIPImageProcessor (VipLlava model)
- vit — ViTImageProcessor or ViTImageProcessorFast (ViT model)
- vit_hybrid — ViTHybridImageProcessor (ViT Hybrid model)
- vit_mae — ViTImageProcessor or ViTImageProcessorFast (ViTMAE model)
- vit_msn — ViTImageProcessor or ViTImageProcessorFast (ViTMSN model)
- vitmatte — VitMatteImageProcessor (ViTMatte model)
- xclip — CLIPImageProcessor (X-CLIP model)
- yolos — YolosImageProcessor (YOLOS model)
- zoedepth — ZoeDepthImageProcessor (ZoeDepth model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoImageProcessor
>>> # Download image processor from huggingface.co and cache.
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
register
< source >( config_class image_processor_class = None slow_image_processor_class = None fast_image_processor_class = None exist_ok = False )
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- image_processor_class (ImageProcessingMixin) — The image processor to register.
Register a new image processor for this class.
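Both a slow and a fast variant can be registered at once (a minimal sketch; the two image processor classes are hypothetical, and NewModelConfig is the config registered in the extension section above):

from transformers import AutoImageProcessor

AutoImageProcessor.register(
    NewModelConfig,
    slow_image_processor_class=NewModelImageProcessor,
    fast_image_processor_class=NewModelImageProcessorFast,
)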
AutoProcessor
This is a generic processor class that will be instantiated as one of the processor classes of the library when created with the AutoProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
< source >( pretrained_model_name_or_path **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained processor hosted inside a model repo on huggingface.co.
  - a path to a directory containing processor files saved using the save_pretrained() method, e.g., ./my_model_directory/.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model processor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force to (re-)download the processor files and override the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final processor object. If True, then this function returns a Tuple(processor, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not processor attributes: i.e., the part of kwargs which has not been used to update processor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are processor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not processor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the processor classes of the library from a pretrained model vocabulary.
The processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible):
- align — AlignProcessor (ALIGN model)
- altclip — AltCLIPProcessor (AltCLIP model)
- bark — BarkProcessor (Bark model)
- blip — BlipProcessor (BLIP model)
- blip-2 — Blip2Processor (BLIP-2 model)
- bridgetower — BridgeTowerProcessor (BridgeTower model)
- chameleon — ChameleonProcessor (Chameleon model)
- chinese_clip — ChineseCLIPProcessor (Chinese-CLIP model)
- clap — ClapProcessor (CLAP model)
- clip — CLIPProcessor (CLIP model)
- clipseg — CLIPSegProcessor (CLIPSeg model)
- clvp — ClvpProcessor (CLVP model)
- flava — FlavaProcessor (FLAVA model)
- fuyu — FuyuProcessor (Fuyu model)
- git — GitProcessor (GIT model)
- grounding-dino — GroundingDinoProcessor (Grounding DINO model)
- groupvit — CLIPProcessor (GroupViT model)
- hubert — Wav2Vec2Processor (Hubert model)
- idefics — IdeficsProcessor (IDEFICS model)
- idefics2 — Idefics2Processor (Idefics2 model)
- idefics3 — Idefics3Processor (Idefics3 model)
- instructblip — InstructBlipProcessor (InstructBLIP model)
- instructblipvideo — InstructBlipVideoProcessor (InstructBlipVideo model)
- kosmos-2 — Kosmos2Processor (KOSMOS-2 model)
- layoutlmv2 — LayoutLMv2Processor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Processor (LayoutLMv3 model)
- llava — LlavaProcessor (LLaVa model)
- llava_next — LlavaNextProcessor (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoProcessor (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionProcessor (LLaVA-Onevision model)
- markuplm — MarkupLMProcessor (MarkupLM model)
- mctct — MCTCTProcessor (M-CTC-T model)
- mgp-str — MgpstrProcessor (MGP-STR model)
- mllama — MllamaProcessor (Mllama model)
- oneformer — OneFormerProcessor (OneFormer model)
- owlv2 — Owlv2Processor (OWLv2 model)
- owlvit — OwlViTProcessor (OWL-ViT model)
- paligemma — PaliGemmaProcessor (PaliGemma model)
- pix2struct — Pix2StructProcessor (Pix2Struct model)
- pixtral — PixtralProcessor (Pixtral model)
- pop2piano — Pop2PianoProcessor (Pop2Piano model)
- qwen2_audio — Qwen2AudioProcessor (Qwen2Audio model)
- qwen2_vl — Qwen2VLProcessor (Qwen2VL model)
- sam — SamProcessor (SAM model)
- seamless_m4t — SeamlessM4TProcessor (SeamlessM4T model)
- sew — Wav2Vec2Processor (SEW model)
- sew-d — Wav2Vec2Processor (SEW-D model)
- siglip — SiglipProcessor (SigLIP model)
- speech_to_text — Speech2TextProcessor (Speech2Text model)
- speech_to_text_2 — Speech2Text2Processor (Speech2Text2 model)
- speecht5 — SpeechT5Processor (SpeechT5 model)
- trocr — TrOCRProcessor (TrOCR model)
- tvlt — TvltProcessor (TVLT model)
- tvp — TvpProcessor (TVP model)
- udop — UdopProcessor (UDOP model)
- unispeech — Wav2Vec2Processor (UniSpeech model)
- unispeech-sat — Wav2Vec2Processor (UniSpeechSat model)
- video_llava — VideoLlavaProcessor (VideoLlava model)
- vilt — ViltProcessor (ViLT model)
- vipllava — LlavaProcessor (VipLlava model)
- vision-text-dual-encoder — VisionTextDualEncoderProcessor (VisionTextDualEncoder model)
- wav2vec2 — Wav2Vec2Processor (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2Processor (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2Processor (Wav2Vec2-Conformer model)
- wavlm — Wav2Vec2Processor (WavLM model)
- whisper — WhisperProcessor (Whisper model)
- xclip — XCLIPProcessor (X-CLIP model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoProcessor
>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
register
< source >( config_class processor_class exist_ok = False )
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- processor_class (FeatureExtractorMixin) — The processor to register.
Register a new processor for this class.
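A processor typically bundles a tokenizer with a feature extractor or image processor. A minimal sketch (assuming a hypothetical NewModelProcessor subclass of ProcessorMixin and the NewModelConfig registered in the extension section above):

from transformers import AutoProcessor

AutoProcessor.register(NewModelConfig, NewModelProcessor)
# AutoProcessor.from_pretrained now resolves "new-model" checkpoints to NewModelProcessor.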
Generic model classes
The following auto classes are available for instantiating a base model class without a specific head.
AutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- ASTConfig configuration class: ASTModel (Audio Spectrogram Transformer model)
- AlbertConfig configuration class: AlbertModel (ALBERT model)
- AlignConfig configuration class: AlignModel (ALIGN model)
- AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
- AutoformerConfig configuration class: AutoformerModel (Autoformer model)
- BarkConfig configuration class: BarkModel (Bark model)
- BartConfig configuration class: BartModel (BART model)
- BeitConfig configuration class: BeitModel (BEiT model)
- BertConfig configuration class: BertModel (BERT model)
- BertGenerationConfig configuration class: BertGenerationEncoder (Bert Generation model)
- BigBirdConfig configuration class: BigBirdModel (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusModel (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptModel (BioGpt model)
- BitConfig configuration class: BitModel (BiT model)
- BlenderbotConfig configuration class: BlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallModel (BlenderbotSmall model)
- Blip2Config configuration class: Blip2Model (BLIP-2 model)
- BlipConfig configuration class: BlipModel (BLIP model)
- BloomConfig configuration class: BloomModel (BLOOM model)
- BridgeTowerConfig configuration class: BridgeTowerModel (BridgeTower model)
- BrosConfig configuration class: BrosModel (BROS model)
- CLIPConfig configuration class: CLIPModel (CLIP model)
- CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
- CLIPTextConfig configuration class: CLIPTextModel (CLIPTextModel model)
- CLIPVisionConfig configuration class: CLIPVisionModel (CLIPVisionModel model)
- CTRLConfig configuration class: CTRLModel (CTRL model)
- CamembertConfig configuration class: CamembertModel (CamemBERT model)
- CanineConfig configuration class: CanineModel (CANINE model)
- ChameleonConfig configuration class: ChameleonModel (Chameleon model)
- ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
- ChineseCLIPVisionConfig configuration class: ChineseCLIPVisionModel (ChineseCLIPVisionModel model)
- ClapConfig configuration class: ClapModel (CLAP model)
- ClvpConfig configuration class: ClvpModelForConditionalGeneration (CLVP model)
- CodeGenConfig configuration class: CodeGenModel (CodeGen model)
- CohereConfig configuration class: CohereModel (Cohere model)
- ConditionalDetrConfig configuration class: ConditionalDetrModel (Conditional DETR model)
- ConvBertConfig configuration class: ConvBertModel (ConvBERT model)
- ConvNextConfig configuration class: ConvNextModel (ConvNeXT model)
- ConvNextV2Config configuration class: ConvNextV2Model (ConvNeXTV2 model)
- CpmAntConfig configuration class: CpmAntModel (CPM-Ant model)
- CvtConfig configuration class: CvtModel (CvT model)
- DPRConfig configuration class: DPRQuestionEncoder (DPR model)
- DPTConfig configuration class: DPTModel (DPT model)
- DacConfig configuration class: DacModel (DAC model)
- Data2VecAudioConfig configuration class: Data2VecAudioModel (Data2VecAudio model)
- Data2VecTextConfig configuration class: Data2VecTextModel (Data2VecText model)
- Data2VecVisionConfig configuration class: Data2VecVisionModel (Data2VecVision model)
- DbrxConfig configuration class: DbrxModel (DBRX model)
- DebertaConfig configuration class: DebertaModel (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2Model (DeBERTa-v2 model)
- DecisionTransformerConfig configuration class: DecisionTransformerModel (Decision Transformer model)
- DeformableDetrConfig configuration class: DeformableDetrModel (Deformable DETR model)
- DeiTConfig configuration class: DeiTModel (DeiT model)
- DetaConfig configuration class: DetaModel (DETA model)
- DetrConfig configuration class: DetrModel (DETR model)
- DinatConfig configuration class: DinatModel (DiNAT model)
- Dinov2Config configuration class: Dinov2Model (DINOv2 model)
- DistilBertConfig configuration class: DistilBertModel (DistilBERT model)
- DonutSwinConfig configuration class: DonutSwinModel (DonutSwin model)
- EfficientFormerConfig configuration class: EfficientFormerModel (EfficientFormer model)
- EfficientNetConfig configuration class: EfficientNetModel (EfficientNet model)
- ElectraConfig configuration class: ElectraModel (ELECTRA model)
- EncodecConfig configuration class: EncodecModel (EnCodec model)
- ErnieConfig configuration class: ErnieModel (ERNIE model)
- ErnieMConfig configuration class: ErnieMModel (ErnieM model)
- EsmConfig configuration class: EsmModel (ESM model)
- FNetConfig configuration class: FNetModel (FNet model)
- FSMTConfig configuration class: FSMTModel (FairSeq Machine-Translation model)
- FalconConfig configuration class: FalconModel (Falcon model)
- FalconMambaConfig configuration class: FalconMambaModel (FalconMamba model)
- FastSpeech2ConformerConfig configuration class: FastSpeech2ConformerModel (FastSpeech2Conformer model)
- FlaubertConfig configuration class: FlaubertModel (FlauBERT model)
- FlavaConfig configuration class: FlavaModel (FLAVA model)
- FocalNetConfig configuration class: FocalNetModel (FocalNet model)
- FunnelConfig configuration class: FunnelModel or FunnelBaseModel (Funnel Transformer model)
- GLPNConfig configuration class: GLPNModel (GLPN model)
- GPT2Config configuration class: GPT2Model (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeModel (GPTBigCode model)
- GPTJConfig configuration class: GPTJModel (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoModel (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXModel (GPT NeoX model)
- GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- Gemma2Config configuration class: Gemma2Model (Gemma2 model)
- GemmaConfig configuration class: GemmaModel (Gemma model)
- GitConfig configuration class: GitModel (GIT model)
- GlmConfig configuration class: GlmModel (GLM model)
- GraniteConfig configuration class: GraniteModel (Granite model)
- GraniteMoeConfig configuration class: GraniteMoeModel (GraniteMoeMoe model)
- GraphormerConfig configuration class: GraphormerModel (Graphormer model)
- GroundingDinoConfig configuration class: GroundingDinoModel (Grounding DINO model)
- GroupViTConfig configuration class: GroupViTModel (GroupViT model)
- HieraConfig configuration class: HieraModel (Hiera model)
- HubertConfig configuration class: HubertModel (Hubert model)
- IBertConfig configuration class: IBertModel (I-BERT model)
- IJepaConfig configuration class: IJepaModel (I-JEPA model)
- Idefics2Config configuration class: Idefics2Model (Idefics2 model)
- Idefics3Config configuration class: Idefics3Model (Idefics3 model)
- IdeficsConfig configuration class: IdeficsModel (IDEFICS model)
- ImageGPTConfig configuration class: ImageGPTModel (ImageGPT model)
- InformerConfig configuration class: InformerModel (Informer model)
- JambaConfig configuration class: JambaModel (Jamba model)
- JetMoeConfig configuration class: JetMoeModel (JetMoe model)
- JukeboxConfig configuration class: JukeboxModel (Jukebox model)
- Kosmos2Config configuration class: Kosmos2Model (KOSMOS-2 model)
- LEDConfig configuration class: LEDModel (LED model)
- LayoutLMConfig configuration class: LayoutLMModel (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2Model (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3Model (LayoutLMv3 model)
- LevitConfig configuration class: LevitModel (LeViT model)
- LiltConfig configuration class: LiltModel (LiLT model)
- LlamaConfig configuration class: LlamaModel (LLaMA model)
- LongT5Config configuration class: LongT5Model (LongT5 model)
- LongformerConfig configuration class: LongformerModel (Longformer model)
- LukeConfig configuration class: LukeModel (LUKE model)
- LxmertConfig configuration class: LxmertModel (LXMERT model)
- M2M100Config configuration class: M2M100Model (M2M100 model)
- MBartConfig configuration class: MBartModel (mBART model)
- MCTCTConfig configuration class: MCTCTModel (M-CTC-T model)
- MPNetConfig configuration class: MPNetModel (MPNet model)
- MT5Config configuration class: MT5Model (MT5 model)
- Mamba2Config configuration class: Mamba2Model (mamba2 model)
- MambaConfig configuration class: MambaModel (Mamba model)
- MarianConfig configuration class: MarianModel (Marian model)
- MarkupLMConfig configuration class: MarkupLMModel (MarkupLM model)
- Mask2FormerConfig configuration class: Mask2FormerModel (Mask2Former model)
- MaskFormerConfig configuration class: MaskFormerModel (MaskFormer model)
- MaskFormerSwinConfig configuration class: MaskFormerSwinModel (MaskFormerSwin model)
- MegaConfig configuration class: MegaModel (MEGA model)
- MegatronBertConfig configuration class: MegatronBertModel (Megatron-BERT model)
- MgpstrConfig configuration class: MgpstrForSceneTextRecognition (MGP-STR model)
- MimiConfig configuration class: MimiModel (Mimi model)
- MistralConfig configuration class: MistralModel (Mistral model)
- MixtralConfig configuration class: MixtralModel (Mixtral model)
- MobileBertConfig configuration class: MobileBertModel (MobileBERT model)
- MobileNetV1Config configuration class: MobileNetV1Model (MobileNetV1 model)
- MobileNetV2Config configuration class: MobileNetV2Model (MobileNetV2 model)
- MobileViTConfig configuration class: MobileViTModel (MobileViT model)
- MobileViTV2Config configuration class: MobileViTV2Model (MobileViTV2 model)
- MoshiConfig configuration class: MoshiModel (Moshi model)
- MptConfig configuration class: MptModel (MPT model)
- MraConfig configuration class: MraModel (MRA model)
- MusicgenConfig configuration class: MusicgenModel (MusicGen model)
- MusicgenMelodyConfig configuration class: MusicgenMelodyModel (MusicGen Melody model)
- MvpConfig configuration class: MvpModel (MVP model)
- NatConfig configuration class: NatModel (NAT model)
- NemotronConfig configuration class: NemotronModel (Nemotron model)
- NezhaConfig configuration class: NezhaModel (Nezha model)
- NllbMoeConfig configuration class: NllbMoeModel (NLLB-MOE model)
- NystromformerConfig configuration class: NystromformerModel (Nyströmformer model)
- OPTConfig configuration class: OPTModel (OPT model)
- Olmo2Config configuration class: Olmo2Model (OLMo2 model)
- OlmoConfig configuration class: OlmoModel (OLMo model)
- OlmoeConfig configuration class: OlmoeModel (OLMoE model)
- OmDetTurboConfig configuration class: OmDetTurboForObjectDetection (OmDet-Turbo model)
- OneFormerConfig configuration class: OneFormerModel (OneFormer model)
- OpenAIGPTConfig configuration class: OpenAIGPTModel (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaModel (OpenLlama model)
- OwlViTConfig configuration class: OwlViTModel (OWL-ViT model)
- Owlv2Config configuration class: Owlv2Model (OWLv2 model)
- PLBartConfig configuration class: PLBartModel (PLBart model)
- PatchTSMixerConfig configuration class: PatchTSMixerModel (PatchTSMixer model)
- PatchTSTConfig configuration class: PatchTSTModel (PatchTST model)
- PegasusConfig configuration class: PegasusModel (Pegasus model)
- PegasusXConfig configuration class: PegasusXModel (PEGASUS-X model)
- PerceiverConfig configuration class: PerceiverModel (Perceiver model)
- PersimmonConfig configuration class: PersimmonModel (Persimmon model)
- Phi3Config configuration class: Phi3Model (Phi3 model)
- PhiConfig configuration class: PhiModel (Phi model)
- PhimoeConfig configuration class: PhimoeModel (Phimoe model)
- PixtralVisionConfig configuration class: PixtralVisionModel (Pixtral model)
- PoolFormerConfig configuration class: PoolFormerModel (PoolFormer model)
- ProphetNetConfig configuration class: ProphetNetModel (ProphetNet model)
- PvtConfig configuration class: PvtModel (PVT model)
- PvtV2Config configuration class: PvtV2Model (PVTv2 model)
- QDQBertConfig configuration class: QDQBertModel (QDQBert model)
- Qwen2AudioEncoderConfig configuration class: Qwen2AudioEncoder (Qwen2AudioEncoder model)
- Qwen2Config configuration class: Qwen2Model (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeModel (Qwen2MoE model)
- Qwen2VLConfig configuration class: Qwen2VLModel (Qwen2VL model)
- RTDetrConfig configuration class: RTDetrModel (RT-DETR model)
- RecurrentGemmaConfig configuration class: RecurrentGemmaModel (RecurrentGemma model)
- ReformerConfig configuration class: ReformerModel (Reformer model)
- RegNetConfig configuration class: RegNetModel (RegNet model)
- RemBertConfig configuration class: RemBertModel (RemBERT model)
- ResNetConfig configuration class: ResNetModel (ResNet model)
- RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
- RoCBertConfig configuration class: RoCBertModel (RoCBert model)
- RoFormerConfig configuration class: RoFormerModel (RoFormer model)
- RobertaConfig configuration class: RobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvModel (RWKV model)
- SEWConfig configuration class: SEWModel (SEW model)
- SEWDConfig configuration class: SEWDModel (SEW-D model)
- SamConfig configuration class: SamModel (SAM model)
- SeamlessM4TConfig configuration class: SeamlessM4TModel (SeamlessM4T model)
- SeamlessM4Tv2Config configuration class: SeamlessM4Tv2Model (SeamlessM4Tv2 model)
- SegGptConfig configuration class: SegGptModel (SegGPT model)
- SegformerConfig configuration class: SegformerModel (SegFormer model)
- SiglipConfig configuration class: SiglipModel (SigLIP model)
- SiglipVisionConfig configuration class: SiglipVisionModel (SiglipVisionModel model)
- Speech2TextConfig configuration class: Speech2TextModel (Speech2Text model)
- SpeechT5Config configuration class: SpeechT5Model (SpeechT5 model)
- SplinterConfig configuration class: SplinterModel (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertModel (SqueezeBERT model)
- StableLmConfig configuration class: StableLmModel (StableLm model)
- Starcoder2Config configuration class: Starcoder2Model (Starcoder2 model)
- SwiftFormerConfig configuration class: SwiftFormerModel (SwiftFormer model)
- Swin2SRConfig configuration class: Swin2SRModel (Swin2SR model)
- SwinConfig configuration class: SwinModel (Swin Transformer model)
- Swinv2Config configuration class: Swinv2Model (Swin Transformer V2 model)
- SwitchTransformersConfig configuration class: SwitchTransformersModel (SwitchTransformers model)
- T5Config configuration class: T5Model (T5 model)
- TableTransformerConfig configuration class: TableTransformerModel (Table Transformer model)
- TapasConfig configuration class: TapasModel (TAPAS model)
- TimeSeriesTransformerConfig configuration class: TimeSeriesTransformerModel (Time Series Transformer model)
- TimesformerConfig configuration class: TimesformerModel (TimeSformer model)
- TimmBackboneConfig configuration class: TimmBackbone (TimmBackbone model)
- TrajectoryTransformerConfig configuration class: TrajectoryTransformerModel (Trajectory Transformer model)
- TransfoXLConfig configuration class: TransfoXLModel (Transformer-XL model)
- TvltConfig configuration class: TvltModel (TVLT model)
- TvpConfig configuration class: TvpModel (TVP model)
- UMT5Config configuration class: UMT5Model (UMT5 model)
- UdopConfig configuration class: UdopModel (UDOP model)
- UniSpeechConfig configuration class: UniSpeechModel (UniSpeech model)
- UniSpeechSatConfig configuration class: UniSpeechSatModel (UniSpeechSat model)
- UnivNetConfig configuration class: UnivNetModel (UnivNet model)
- VanConfig configuration class: VanModel (VAN model)
- ViTConfig configuration class: ViTModel (ViT model)
- ViTHybridConfig configuration class: ViTHybridModel (ViT Hybrid model)
- ViTMAEConfig configuration class: ViTMAEModel (ViTMAE model)
- ViTMSNConfig configuration class: ViTMSNModel (ViTMSN model)
- VideoMAEConfig configuration class: VideoMAEModel (VideoMAE model)
- ViltConfig configuration class: ViltModel (ViLT model)
- VisionTextDualEncoderConfig configuration class: VisionTextDualEncoderModel (VisionTextDualEncoder model)
- VisualBertConfig configuration class: VisualBertModel (VisualBERT model)
- VitDetConfig configuration class: VitDetModel (VitDet model)
- VitsConfig configuration class: VitsModel (VITS model)
- VivitConfig configuration class: VivitModel (ViViT model)
- Wav2Vec2BertConfig configuration class: Wav2Vec2BertModel (Wav2Vec2-BERT model)
- Wav2Vec2Config configuration class: Wav2Vec2Model (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- WavLMConfig configuration class: WavLMModel (WavLM model)
- WhisperConfig configuration class: WhisperModel (Whisper model)
- XCLIPConfig configuration class: XCLIPModel (X-CLIP model)
- XGLMConfig configuration class: XGLMModel (XGLM model)
- XLMConfig configuration class: XLMModel (XLM model)
- XLMProphetNetConfig configuration class: XLMProphetNetModel (XLM-ProphetNet model)
- XLMRobertaConfig configuration class: XLMRobertaModel (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLModel (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetModel (XLNet model)
- XmodConfig configuration class: XmodModel (X-MOD model)
- YolosConfig configuration class: YolosModel (YOLOS model)
- YosoConfig configuration class: YosoModel (YOSO model)
- ZambaConfig configuration class: ZambaModel (Zamba model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual `"eager"` implementation.
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
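For example, a minimal sketch using the same `google-bert/bert-base-cased` checkpoint as elsewhere on this page; the last line simply selects the manual attention backend described above:
>>> from transformers import AutoConfig, AutoModel

>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModel.from_config(config)
>>> # The attention backend can be chosen here as well, e.g. the manual implementation:
>>> model = AutoModel.from_config(config, attn_implementation="eager")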
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option (a short sketch follows this parameter list).
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
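To make the `state_dict` override above concrete, here is a minimal sketch; `./my_weights.bin` is a hypothetical file produced earlier with `torch.save(model.state_dict(), ...)`:
>>> import torch
>>> from transformers import AutoModel

>>> # Hypothetical custom weights file; any state dict matching the architecture works.
>>> state_dict = torch.load("./my_weights.bin")
>>> # Architecture comes from the pretrained configuration, weights from state_dict.
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased", state_dict=state_dict)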
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — AlbertModel (ALBERT model)
- align — AlignModel (ALIGN model)
- altclip — AltCLIPModel (AltCLIP model)
- audio-spectrogram-transformer — ASTModel (Audio Spectrogram Transformer model)
- autoformer — AutoformerModel (Autoformer model)
- bark — BarkModel (Bark model)
- bart — BartModel (BART model)
- beit — BeitModel (BEiT model)
- bert — BertModel (BERT model)
- bert-generation — BertGenerationEncoder (Bert Generation model)
- big_bird — BigBirdModel (BigBird model)
- bigbird_pegasus — BigBirdPegasusModel (BigBird-Pegasus model)
- biogpt — BioGptModel (BioGpt model)
- bit — BitModel (BiT model)
- blenderbot — BlenderbotModel (Blenderbot model)
- blenderbot-small — BlenderbotSmallModel (BlenderbotSmall model)
- blip — BlipModel (BLIP model)
- blip-2 — Blip2Model (BLIP-2 model)
- bloom — BloomModel (BLOOM model)
- bridgetower — BridgeTowerModel (BridgeTower model)
- bros — BrosModel (BROS model)
- camembert — CamembertModel (CamemBERT model)
- canine — CanineModel (CANINE model)
- chameleon — ChameleonModel (Chameleon model)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP model)
- chinese_clip_vision_model — ChineseCLIPVisionModel (ChineseCLIPVisionModel model)
- clap — ClapModel (CLAP model)
- clip — CLIPModel (CLIP model)
- clip_text_model — CLIPTextModel (CLIPTextModel model)
- clip_vision_model — CLIPVisionModel (CLIPVisionModel model)
- clipseg — CLIPSegModel (CLIPSeg model)
- clvp — ClvpModelForConditionalGeneration (CLVP model)
- code_llama — LlamaModel (CodeLlama model)
- codegen — CodeGenModel (CodeGen model)
- cohere — CohereModel (Cohere model)
- conditional_detr — ConditionalDetrModel (Conditional DETR model)
- convbert — ConvBertModel (ConvBERT model)
- convnext — ConvNextModel (ConvNeXT model)
- convnextv2 — ConvNextV2Model (ConvNeXTV2 model)
- cpmant — CpmAntModel (CPM-Ant model)
- ctrl — CTRLModel (CTRL model)
- cvt — CvtModel (CvT model)
- dac — DacModel (DAC model)
- data2vec-audio — Data2VecAudioModel (Data2VecAudio model)
- data2vec-text — Data2VecTextModel (Data2VecText model)
- data2vec-vision — Data2VecVisionModel (Data2VecVision model)
- dbrx — DbrxModel (DBRX model)
- deberta — DebertaModel (DeBERTa model)
- deberta-v2 — DebertaV2Model (DeBERTa-v2 model)
- decision_transformer — DecisionTransformerModel (Decision Transformer model)
- deformable_detr — DeformableDetrModel (Deformable DETR model)
- deit — DeiTModel (DeiT model)
- deta — DetaModel (DETA model)
- detr — DetrModel (DETR model)
- dinat — DinatModel (DiNAT model)
- dinov2 — Dinov2Model (DINOv2 model)
- distilbert — DistilBertModel (DistilBERT model)
- donut-swin — DonutSwinModel (DonutSwin model)
- dpr — DPRQuestionEncoder (DPR model)
- dpt — DPTModel (DPT model)
- efficientformer — EfficientFormerModel (EfficientFormer model)
- efficientnet — EfficientNetModel (EfficientNet model)
- electra — ElectraModel (ELECTRA model)
- encodec — EncodecModel (EnCodec model)
- ernie — ErnieModel (ERNIE model)
- ernie_m — ErnieMModel (ErnieM model)
- esm — EsmModel (ESM model)
- falcon — FalconModel (Falcon model)
- falcon_mamba — FalconMambaModel (FalconMamba model)
- fastspeech2_conformer — FastSpeech2ConformerModel (FastSpeech2Conformer model)
- flaubert — FlaubertModel (FlauBERT model)
- flava — FlavaModel (FLAVA model)
- fnet — FNetModel (FNet model)
- focalnet — FocalNetModel (FocalNet model)
- fsmt — FSMTModel (FairSeq Machine-Translation model)
- funnel — FunnelModel or FunnelBaseModel (Funnel Transformer model)
- gemma — GemmaModel (Gemma model)
- gemma2 — Gemma2Model (Gemma2 model)
- git — GitModel (GIT model)
- glm — GlmModel (GLM model)
- glpn — GLPNModel (GLPN model)
- gpt-sw3 — GPT2Model (GPT-Sw3 model)
- gpt2 — GPT2Model (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeModel (GPTBigCode model)
- gpt_neo — GPTNeoModel (GPT Neo model)
- gpt_neox — GPTNeoXModel (GPT NeoX model)
- gpt_neox_japanese — GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- gptj — GPTJModel (GPT-J model)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- granite — GraniteModel (Granite model)
- granitemoe — GraniteMoeModel (GraniteMoe model)
- graphormer — GraphormerModel (Graphormer model)
- grounding-dino — GroundingDinoModel (Grounding DINO model)
- groupvit — GroupViTModel (GroupViT model)
- hiera — HieraModel (Hiera model)
- hubert — HubertModel (Hubert model)
- ibert — IBertModel (I-BERT model)
- idefics — IdeficsModel (IDEFICS model)
- idefics2 — Idefics2Model (Idefics2 model)
- idefics3 — Idefics3Model (Idefics3 model)
- ijepa — IJepaModel (I-JEPA model)
- imagegpt — ImageGPTModel (ImageGPT model)
- informer — InformerModel (Informer model)
- jamba — JambaModel (Jamba model)
- jetmoe — JetMoeModel (JetMoe model)
- jukebox — JukeboxModel (Jukebox model)
- kosmos-2 — Kosmos2Model (KOSMOS-2 model)
- layoutlm — LayoutLMModel (LayoutLM model)
- layoutlmv2 — LayoutLMv2Model (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Model (LayoutLMv3 model)
- led — LEDModel (LED model)
- levit — LevitModel (LeViT model)
- lilt — LiltModel (LiLT model)
- llama — LlamaModel (LLaMA model)
- longformer — LongformerModel (Longformer model)
- longt5 — LongT5Model (LongT5 model)
- luke — LukeModel (LUKE model)
- lxmert — LxmertModel (LXMERT model)
- m2m_100 — M2M100Model (M2M100 model)
- mamba — MambaModel (Mamba model)
- mamba2 — Mamba2Model (mamba2 model)
- marian — MarianModel (Marian model)
- markuplm — MarkupLMModel (MarkupLM model)
- mask2former — Mask2FormerModel (Mask2Former model)
- maskformer — MaskFormerModel (MaskFormer model)
- maskformer-swin — MaskFormerSwinModel (MaskFormerSwin model)
- mbart — MBartModel (mBART model)
- mctct — MCTCTModel (M-CTC-T model)
- mega — MegaModel (MEGA model)
- megatron-bert — MegatronBertModel (Megatron-BERT model)
- mgp-str — MgpstrForSceneTextRecognition (MGP-STR model)
- mimi — MimiModel (Mimi model)
- mistral — MistralModel (Mistral model)
- mixtral — MixtralModel (Mixtral model)
- mobilebert — MobileBertModel (MobileBERT model)
- mobilenet_v1 — MobileNetV1Model (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2Model (MobileNetV2 model)
- mobilevit — MobileViTModel (MobileViT model)
- mobilevitv2 — MobileViTV2Model (MobileViTV2 model)
- moshi — MoshiModel (Moshi model)
- mpnet — MPNetModel (MPNet model)
- mpt — MptModel (MPT model)
- mra — MraModel (MRA model)
- mt5 — MT5Model (MT5 model)
- musicgen — MusicgenModel (MusicGen model)
- musicgen_melody — MusicgenMelodyModel (MusicGen Melody model)
- mvp — MvpModel (MVP model)
- nat — NatModel (NAT model)
- nemotron — NemotronModel (Nemotron model)
- nezha — NezhaModel (Nezha model)
- nllb-moe — NllbMoeModel (NLLB-MOE model)
- nystromformer — NystromformerModel (Nyströmformer model)
- olmo — OlmoModel (OLMo model)
- olmo2 — Olmo2Model (OLMo2 model)
- olmoe — OlmoeModel (OLMoE model)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo model)
- oneformer — OneFormerModel (OneFormer model)
- open-llama — OpenLlamaModel (OpenLlama model)
- openai-gpt — OpenAIGPTModel (OpenAI GPT model)
- opt — OPTModel (OPT model)
- owlv2 — Owlv2Model (OWLv2 model)
- owlvit — OwlViTModel (OWL-ViT model)
- patchtsmixer — PatchTSMixerModel (PatchTSMixer model)
- patchtst — PatchTSTModel (PatchTST model)
- pegasus — PegasusModel (Pegasus model)
- pegasus_x — PegasusXModel (PEGASUS-X model)
- perceiver — PerceiverModel (Perceiver model)
- persimmon — PersimmonModel (Persimmon model)
- phi — PhiModel (Phi model)
- phi3 — Phi3Model (Phi3 model)
- phimoe — PhimoeModel (Phimoe model)
- pixtral — PixtralVisionModel (Pixtral model)
- plbart — PLBartModel (PLBart model)
- poolformer — PoolFormerModel (PoolFormer model)
- prophetnet — ProphetNetModel (ProphetNet model)
- pvt — PvtModel (PVT model)
- pvt_v2 — PvtV2Model (PVTv2 model)
- qdqbert — QDQBertModel (QDQBert model)
- qwen2 — Qwen2Model (Qwen2 model)
- qwen2_audio_encoder — Qwen2AudioEncoder (Qwen2AudioEncoder model)
- qwen2_moe — Qwen2MoeModel (Qwen2MoE model)
- qwen2_vl — Qwen2VLModel (Qwen2VL model)
- recurrent_gemma — RecurrentGemmaModel (RecurrentGemma model)
- reformer — ReformerModel (Reformer model)
- regnet — RegNetModel (RegNet model)
- rembert — RemBertModel (RemBERT model)
- resnet — ResNetModel (ResNet model)
- retribert — RetriBertModel (RetriBERT model)
- roberta — RobertaModel (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertModel (RoCBert model)
- roformer — RoFormerModel (RoFormer model)
- rt_detr — RTDetrModel (RT-DETR model)
- rwkv — RwkvModel (RWKV model)
- sam — SamModel (SAM model)
- seamless_m4t — SeamlessM4TModel (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2Model (SeamlessM4Tv2 model)
- segformer — SegformerModel (SegFormer model)
- seggpt — SegGptModel (SegGPT model)
- sew — SEWModel (SEW model)
- sew-d — SEWDModel (SEW-D model)
- siglip — SiglipModel (SigLIP model)
- siglip_vision_model — SiglipVisionModel (SiglipVisionModel model)
- speech_to_text — Speech2TextModel (Speech2Text model)
- speecht5 — SpeechT5Model (SpeechT5 model)
- splinter — SplinterModel (Splinter model)
- squeezebert — SqueezeBertModel (SqueezeBERT model)
- stablelm — StableLmModel (StableLm model)
- starcoder2 — Starcoder2Model (Starcoder2 model)
- swiftformer — SwiftFormerModel (SwiftFormer model)
- swin — SwinModel (Swin Transformer model)
- swin2sr — Swin2SRModel (Swin2SR model)
- swinv2 — Swinv2Model (Swin Transformer V2 model)
- switch_transformers — SwitchTransformersModel (SwitchTransformers model)
- t5 — T5Model (T5 model)
- table-transformer — TableTransformerModel (Table Transformer model)
- tapas — TapasModel (TAPAS model)
- time_series_transformer — TimeSeriesTransformerModel (Time Series Transformer model)
- timesformer — TimesformerModel (TimeSformer model)
- timm_backbone — TimmBackbone (TimmBackbone model)
- trajectory_transformer — TrajectoryTransformerModel (Trajectory Transformer model)
- transfo-xl — TransfoXLModel (Transformer-XL model)
- tvlt — TvltModel (TVLT model)
- tvp — TvpModel (TVP model)
- udop — UdopModel (UDOP model)
- umt5 — UMT5Model (UMT5 model)
- unispeech — UniSpeechModel (UniSpeech model)
- unispeech-sat — UniSpeechSatModel (UniSpeechSat model)
- univnet — UnivNetModel (UnivNet model)
- van — VanModel (VAN model)
- videomae — VideoMAEModel (VideoMAE model)
- vilt — ViltModel (ViLT model)
- vision-text-dual-encoder — VisionTextDualEncoderModel (VisionTextDualEncoder model)
- visual_bert — VisualBertModel (VisualBERT model)
- vit — ViTModel (ViT model)
- vit_hybrid — ViTHybridModel (ViT Hybrid model)
- vit_mae — ViTMAEModel (ViTMAE model)
- vit_msn — ViTMSNModel (ViTMSN model)
- vitdet — VitDetModel (VitDet model)
- vits — VitsModel (VITS model)
- vivit — VivitModel (ViViT model)
- wav2vec2 — Wav2Vec2Model (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertModel (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- wavlm — WavLMModel (WavLM model)
- whisper — WhisperModel (Whisper model)
- xclip — XCLIPModel (X-CLIP model)
- xglm — XGLMModel (XGLM model)
- xlm — XLMModel (XLM model)
- xlm-prophetnet — XLMProphetNetModel (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaModel (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLModel (XLM-RoBERTa-XL model)
- xlnet — XLNetModel (XLNet model)
- xmod — XmodModel (X-MOD model)
- yolos — YolosModel (YOLOS model)
- yoso — YosoModel (YOSO model)
- zamba — ZambaModel (Zamba model)
The model is set in evaluation mode by default using `model.eval()` (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Example:
>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModel.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
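Because the model comes back in evaluation mode, switching to training mode is a one-liner; a quick sketch:
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased")
>>> model.training  # evaluation mode by default
False
>>> model = model.train()  # reactivates dropout etc. for training
>>> model.training
True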
TFAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertModel (ALBERT model)
- BartConfig configuration class: TFBartModel (BART model)
- BertConfig configuration class: TFBertModel (BERT model)
- BlenderbotConfig configuration class: TFBlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: TFBlenderbotSmallModel (BlenderbotSmall model)
- BlipConfig configuration class: TFBlipModel (BLIP model)
- CLIPConfig configuration class: TFCLIPModel (CLIP model)
- CTRLConfig configuration class: TFCTRLModel (CTRL model)
- CamembertConfig configuration class: TFCamembertModel (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertModel (ConvBERT model)
- ConvNextConfig configuration class: TFConvNextModel (ConvNeXT model)
- ConvNextV2Config configuration class: TFConvNextV2Model (ConvNeXTV2 model)
- CvtConfig configuration class: TFCvtModel (CvT model)
- DPRConfig configuration class: TFDPRQuestionEncoder (DPR model)
- Data2VecVisionConfig configuration class: TFData2VecVisionModel (Data2VecVision model)
- DebertaConfig configuration class: TFDebertaModel (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2Model (DeBERTa-v2 model)
- DeiTConfig configuration class: TFDeiTModel (DeiT model)
- DistilBertConfig configuration class: TFDistilBertModel (DistilBERT model)
- EfficientFormerConfig configuration class: TFEfficientFormerModel (EfficientFormer model)
- ElectraConfig configuration class: TFElectraModel (ELECTRA model)
- EsmConfig configuration class: TFEsmModel (ESM model)
- FlaubertConfig configuration class: TFFlaubertModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2Model (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJModel (GPT-J model)
- GroupViTConfig configuration class: TFGroupViTModel (GroupViT model)
- HubertConfig configuration class: TFHubertModel (Hubert model)
- IdeficsConfig configuration class: TFIdeficsModel (IDEFICS model)
- LEDConfig configuration class: TFLEDModel (LED model)
- LayoutLMConfig configuration class: TFLayoutLMModel (LayoutLM model)
- LayoutLMv3Config configuration class: TFLayoutLMv3Model (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerModel (Longformer model)
- LxmertConfig configuration class: TFLxmertModel (LXMERT model)
- MBartConfig configuration class: TFMBartModel (mBART model)
- MPNetConfig configuration class: TFMPNetModel (MPNet model)
- MT5Config configuration class: TFMT5Model (MT5 model)
- MarianConfig configuration class: TFMarianModel (Marian model)
- MistralConfig configuration class: TFMistralModel (Mistral model)
- MobileBertConfig configuration class: TFMobileBertModel (MobileBERT model)
- MobileViTConfig configuration class: TFMobileViTModel (MobileViT model)
- OPTConfig configuration class: TFOPTModel (OPT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTModel (OpenAI GPT model)
- PegasusConfig configuration class: TFPegasusModel (Pegasus model)
- RegNetConfig configuration class: TFRegNetModel (RegNet model)
- RemBertConfig configuration class: TFRemBertModel (RemBERT model)
- ResNetConfig configuration class: TFResNetModel (ResNet model)
- RoFormerConfig configuration class: TFRoFormerModel (RoFormer model)
- RobertaConfig configuration class: TFRobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- SamConfig configuration class: TFSamModel (SAM model)
- SegformerConfig configuration class: TFSegformerModel (SegFormer model)
- Speech2TextConfig configuration class: TFSpeech2TextModel (Speech2Text model)
- SwiftFormerConfig configuration class: TFSwiftFormerModel (SwiftFormer model)
- SwinConfig configuration class: TFSwinModel (Swin Transformer model)
- T5Config configuration class: TFT5Model (T5 model)
- TapasConfig configuration class: TFTapasModel (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLModel (Transformer-XL model)
- ViTConfig configuration class: TFViTModel (ViT model)
- ViTMAEConfig configuration class: TFViTMAEModel (ViTMAE model)
- VisionTextDualEncoderConfig configuration class: TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
- Wav2Vec2Config configuration class: TFWav2Vec2Model (Wav2Vec2 model)
- WhisperConfig configuration class: TFWhisperModel (Whisper model)
- XGLMConfig configuration class: TFXGLMModel (XGLM model)
- XLMConfig configuration class: TFXLMModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaModel (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetModel (XLNet model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual `"eager"` implementation.
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
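For example, a minimal sketch mirroring the PyTorch case above:
>>> from transformers import AutoConfig, TFAutoModel

>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModel.from_config(config)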
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as `config` argument. This loading path is slower than converting the PyTorch model in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages (see the sketch after this parameter list).
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
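As an illustration of the `output_loading_info` flag described above, a minimal sketch (the exact keys present in the dictionary depend on the checkpoint):
>>> from transformers import TFAutoModel

>>> # Returns the model together with a dict of missing keys, unexpected keys and error messages.
>>> model, loading_info = TFAutoModel.from_pretrained(
...     "google-bert/bert-base-cased", output_loading_info=True
... )
>>> sorted(loading_info.keys())  # e.g. includes "missing_keys" and "unexpected_keys"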
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — TFAlbertModel (ALBERT model)
- bart — TFBartModel (BART model)
- bert — TFBertModel (BERT model)
- blenderbot — TFBlenderbotModel (Blenderbot model)
- blenderbot-small — TFBlenderbotSmallModel (BlenderbotSmall model)
- blip — TFBlipModel (BLIP model)
- camembert — TFCamembertModel (CamemBERT model)
- clip — TFCLIPModel (CLIP model)
- convbert — TFConvBertModel (ConvBERT model)
- convnext — TFConvNextModel (ConvNeXT model)
- convnextv2 — TFConvNextV2Model (ConvNeXTV2 model)
- ctrl — TFCTRLModel (CTRL model)
- cvt — TFCvtModel (CvT model)
- data2vec-vision — TFData2VecVisionModel (Data2VecVision model)
- deberta — TFDebertaModel (DeBERTa model)
- deberta-v2 — TFDebertaV2Model (DeBERTa-v2 model)
- deit — TFDeiTModel (DeiT model)
- distilbert — TFDistilBertModel (DistilBERT model)
- dpr — TFDPRQuestionEncoder (DPR model)
- efficientformer — TFEfficientFormerModel (EfficientFormer model)
- electra — TFElectraModel (ELECTRA model)
- esm — TFEsmModel (ESM model)
- flaubert — TFFlaubertModel (FlauBERT model)
- funnel — TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
- gpt-sw3 — TFGPT2Model (GPT-Sw3 model)
- gpt2 — TFGPT2Model (OpenAI GPT-2 model)
- gptj — TFGPTJModel (GPT-J model)
- groupvit — TFGroupViTModel (GroupViT model)
- hubert — TFHubertModel (Hubert model)
- idefics — TFIdeficsModel (IDEFICS model)
- layoutlm — TFLayoutLMModel (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3Model (LayoutLMv3 model)
- led — TFLEDModel (LED model)
- longformer — TFLongformerModel (Longformer model)
- lxmert — TFLxmertModel (LXMERT model)
- marian — TFMarianModel (Marian model)
- mbart — TFMBartModel (mBART model)
- mistral — TFMistralModel (Mistral model)
- mobilebert — TFMobileBertModel (MobileBERT model)
- mobilevit — TFMobileViTModel (MobileViT model)
- mpnet — TFMPNetModel (MPNet model)
- mt5 — TFMT5Model (MT5 model)
- openai-gpt — TFOpenAIGPTModel (OpenAI GPT model)
- opt — TFOPTModel (OPT model)
- pegasus — TFPegasusModel (Pegasus model)
- regnet — TFRegNetModel (RegNet model)
- rembert — TFRemBertModel (RemBERT model)
- resnet — TFResNetModel (ResNet model)
- roberta — TFRobertaModel (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerModel (RoFormer model)
- sam — TFSamModel (SAM model)
- segformer — TFSegformerModel (SegFormer model)
- speech_to_text — TFSpeech2TextModel (Speech2Text model)
- swiftformer — TFSwiftFormerModel (SwiftFormer model)
- swin — TFSwinModel (Swin Transformer model)
- t5 — TFT5Model (T5 model)
- tapas — TFTapasModel (TAPAS model)
- transfo-xl — TFTransfoXLModel (Transformer-XL model)
- vision-text-dual-encoder — TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit — TFViTModel (ViT model)
- vit_mae — TFViTMAEModel (ViTMAE model)
- wav2vec2 — TFWav2Vec2Model (Wav2Vec2 model)
- whisper — TFWhisperModel (Whisper model)
- xglm — TFXGLMModel (XGLM model)
- xlm — TFXLMModel (XLM model)
- xlm-roberta — TFXLMRobertaModel (XLM-RoBERTa model)
- xlnet — TFXLNetModel (XLNet model)
Example:
>>> from transformers import AutoConfig, TFAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertModel (ALBERT model)
- BartConfig configuration class: FlaxBartModel (BART model)
- BeitConfig configuration class: FlaxBeitModel (BEiT model)
- BertConfig configuration class: FlaxBertModel (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdModel (BigBird model)
- BlenderbotConfig configuration class: FlaxBlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallModel (BlenderbotSmall model)
- BloomConfig configuration class: FlaxBloomModel (BLOOM model)
- CLIPConfig configuration class: FlaxCLIPModel (CLIP model)
- Dinov2Config configuration class: FlaxDinov2Model (DINOv2 model)
- DistilBertConfig configuration class: FlaxDistilBertModel (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraModel (ELECTRA model)
- GPT2Config configuration class: FlaxGPT2Model (OpenAI GPT-2 model)
- GPTJConfig configuration class: FlaxGPTJModel (GPT-J model)
- GPTNeoConfig configuration class: FlaxGPTNeoModel (GPT Neo model)
- GemmaConfig configuration class: FlaxGemmaModel (Gemma model)
- LlamaConfig configuration class: FlaxLlamaModel (LLaMA model)
- LongT5Config configuration class: FlaxLongT5Model (LongT5 model)
- MBartConfig configuration class: FlaxMBartModel (mBART model)
- MT5Config configuration class: FlaxMT5Model (MT5 model)
- MarianConfig configuration class: FlaxMarianModel (Marian model)
- MistralConfig configuration class: FlaxMistralModel (Mistral model)
- OPTConfig configuration class: FlaxOPTModel (OPT model)
- PegasusConfig configuration class: FlaxPegasusModel (Pegasus model)
- RegNetConfig configuration class: FlaxRegNetModel (RegNet model)
- ResNetConfig configuration class: FlaxResNetModel (ResNet model)
- RoFormerConfig configuration class: FlaxRoFormerModel (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- T5Config configuration class: FlaxT5Model (T5 model)
- ViTConfig configuration class: FlaxViTModel (ViT model)
- VisionTextDualEncoderConfig configuration class: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- Wav2Vec2Config configuration class: FlaxWav2Vec2Model (Wav2Vec2 model)
- WhisperConfig configuration class: FlaxWhisperModel (Whisper model)
- XGLMConfig configuration class: FlaxXGLMModel (XGLM model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaModel (XLM-RoBERTa model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual `"eager"` implementation.
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
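For example, a minimal sketch:
>>> from transformers import AutoConfig, FlaxAutoModel

>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModel.from_config(config)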
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as `config` argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — FlaxAlbertModel (ALBERT model)
- bart — FlaxBartModel (BART model)
- beit — FlaxBeitModel (BEiT model)
- bert — FlaxBertModel (BERT model)
- big_bird — FlaxBigBirdModel (BigBird model)
- blenderbot — FlaxBlenderbotModel (Blenderbot model)
- blenderbot-small — FlaxBlenderbotSmallModel (BlenderbotSmall model)
- bloom — FlaxBloomModel (BLOOM model)
- clip — FlaxCLIPModel (CLIP model)
- dinov2 — FlaxDinov2Model (DINOv2 model)
- distilbert — FlaxDistilBertModel (DistilBERT model)
- electra — FlaxElectraModel (ELECTRA model)
- gemma — FlaxGemmaModel (Gemma model)
- gpt-sw3 — FlaxGPT2Model (GPT-Sw3 model)
- gpt2 — FlaxGPT2Model (OpenAI GPT-2 model)
- gpt_neo — FlaxGPTNeoModel (GPT Neo model)
- gptj — FlaxGPTJModel (GPT-J model)
- llama — FlaxLlamaModel (LLaMA model)
- longt5 — FlaxLongT5Model (LongT5 model)
- marian — FlaxMarianModel (Marian model)
- mbart — FlaxMBartModel (mBART model)
- mistral — FlaxMistralModel (Mistral model)
- mt5 — FlaxMT5Model (MT5 model)
- opt — FlaxOPTModel (OPT model)
- pegasus — FlaxPegasusModel (Pegasus model)
- regnet — FlaxRegNetModel (RegNet model)
- resnet — FlaxResNetModel (ResNet model)
- roberta — FlaxRobertaModel (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerModel (RoFormer model)
- t5 — FlaxT5Model (T5 model)
- vision-text-dual-encoder — FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit — FlaxViTModel (ViT model)
- wav2vec2 — FlaxWav2Vec2Model (Wav2Vec2 model)
- whisper — FlaxWhisperModel (Whisper model)
- xglm — FlaxXGLMModel (XGLM model)
- xlm-roberta — FlaxXLMRobertaModel (XLM-RoBERTa model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
Generic pretraining classes
The following auto classes are available for instantiating a model with a pretraining head.
AutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForPreTraining (ALBERT model)
- BartConfig configuration class: BartForConditionalGeneration (BART model)
- BertConfig configuration class: BertForPreTraining (BERT model)
- BigBirdConfig configuration class: BigBirdForPreTraining (BigBird model)
- BloomConfig configuration class: BloomForCausalLM (BLOOM model)
- CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: CamembertForMaskedLM (CamemBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForMaskedLM (Data2VecText model)
- DebertaConfig configuration class: DebertaForMaskedLM (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForMaskedLM (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: ElectraForPreTraining (ELECTRA model)
- ErnieConfig configuration class: ErnieForPreTraining (ERNIE model)
- FNetConfig configuration class: FNetForPreTraining (FNet model)
- FSMTConfig configuration class: FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- FalconMambaConfig configuration class: FalconMambaForCausalLM (FalconMamba model)
- FlaubertConfig configuration class: FlaubertWithLMHeadModel (FlauBERT model)
- FlavaConfig configuration class: FlavaForPreTraining (FLAVA model)
- FunnelConfig configuration class: FunnelForPreTraining (Funnel Transformer model)
- GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- HieraConfig configuration class: HieraForPreTraining (Hiera model)
- IBertConfig configuration class: IBertForMaskedLM (I-BERT model)
- Idefics2Config configuration class: Idefics2ForConditionalGeneration (Idefics2 model)
- Idefics3Config configuration class: Idefics3ForConditionalGeneration (Idefics3 model)
- IdeficsConfig configuration class: IdeficsForVisionText2Text (IDEFICS model)
- LayoutLMConfig configuration class: LayoutLMForMaskedLM (LayoutLM model)
- LlavaConfig configuration class: LlavaForConditionalGeneration (LLaVa model)
- LlavaNextConfig configuration class: LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- LlavaNextVideoConfig configuration class: LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
- LlavaOnevisionConfig configuration class: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- LongformerConfig configuration class: LongformerForMaskedLM (Longformer model)
- LukeConfig configuration class: LukeForMaskedLM (LUKE model)
- LxmertConfig configuration class: LxmertForPreTraining (LXMERT model)
- MPNetConfig configuration class: MPNetForMaskedLM (MPNet model)
- Mamba2Config configuration class: Mamba2ForCausalLM (mamba2 model)
- MambaConfig configuration class: MambaForCausalLM (Mamba model)
- MegaConfig configuration class: MegaForMaskedLM (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForPreTraining (Megatron-BERT model)
- MllamaConfig configuration class: MllamaForConditionalGeneration (Mllama model)
- MobileBertConfig configuration class: MobileBertForPreTraining (MobileBERT model)
- MptConfig configuration class: MptForCausalLM (MPT model)
- MraConfig configuration class: MraForMaskedLM (MRA model)
- MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
- NezhaConfig configuration class: NezhaForPreTraining (Nezha model)
- NllbMoeConfig configuration class: NllbMoeForConditionalGeneration (NLLB-MOE model)
- OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
- PaliGemmaConfig configuration class: PaliGemmaForConditionalGeneration (PaliGemma model)
- Qwen2AudioConfig configuration class: Qwen2AudioForConditionalGeneration (Qwen2Audio model)
- RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
- RoCBertConfig configuration class: RoCBertForPreTraining (RoCBert model)
- RobertaConfig configuration class: RobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
- SplinterConfig configuration class: SplinterForPreTraining (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertForMaskedLM (SqueezeBERT model)
- SwitchTransformersConfig configuration class: SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- T5Config configuration class: T5ForConditionalGeneration (T5 model)
- TapasConfig configuration class: TapasForMaskedLM (TAPAS model)
- TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
- TvltConfig configuration class: TvltForPreTraining (TVLT model)
- UniSpeechConfig configuration class: UniSpeechForPreTraining (UniSpeech model)
- UniSpeechSatConfig configuration class: UniSpeechSatForPreTraining (UniSpeechSat model)
- ViTMAEConfig configuration class: ViTMAEForPreTraining (ViTMAE model)
- VideoLlavaConfig configuration class: VideoLlavaForConditionalGeneration (VideoLlava model)
- VideoMAEConfig configuration class: VideoMAEForPreTraining (VideoMAE model)
- VipLlavaConfig configuration class: VipLlavaForConditionalGeneration (VipLlava model)
- VisualBertConfig configuration class: VisualBertForPreTraining (VisualBERT model)
- Wav2Vec2Config configuration class: Wav2Vec2ForPreTraining (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer model)
- XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
- XmodConfig configuration class: XmodForMaskedLM (X-MOD model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
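For example, a minimal sketch; per the mapping above, a BERT configuration yields a BertForPreTraining architecture with randomly initialized weights:
>>> from transformers import AutoConfig, AutoModelForPreTraining

>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModelForPreTraining.from_config(config)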
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — AlbertForPreTraining (ALBERT 模型)
- bart — BartForConditionalGeneration (BART 模型)
- bert — BertForPreTraining (BERT 模型)
- big_bird — BigBirdForPreTraining (BigBird 模型)
- bloom — BloomForCausalLM (BLOOM 模型)
- camembert — CamembertForMaskedLM (CamemBERT 模型)
- ctrl — CTRLLMHeadModel (CTRL 模型)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText 模型)
- deberta — DebertaForMaskedLM (DeBERTa 模型)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- distilbert — DistilBertForMaskedLM (DistilBERT 模型)
- electra — ElectraForPreTraining (ELECTRA 模型)
- ernie — ErnieForPreTraining (ERNIE 模型)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba 模型)
- flaubert — FlaubertWithLMHeadModel (FlauBERT 模型)
- flava — FlavaForPreTraining (FLAVA 模型)
- fnet — FNetForPreTraining (FNet 模型)
- fsmt — FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- funnel — FunnelForPreTraining (漏斗变压器模型)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode 模型)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- hiera — HieraForPreTraining (Hiera 模型)
- ibert — IBertForMaskedLM (I-BERT 模型)
- idefics — IdeficsForVisionText2Text (IDEFICS 模型)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 模型)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 模型)
- layoutlm — LayoutLMForMaskedLM (LayoutLM 模型)
- llava — LlavaForConditionalGeneration (LLaVa 模型)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NEXT 模型)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- longformer — LongformerForMaskedLM (Longformer 模型)
- luke — LukeForMaskedLM (LUKE 模型)
- lxmert — LxmertForPreTraining (LXMERT 模型)
- mamba — MambaForCausalLM (Mamba 模型)
- mamba2 — Mamba2ForCausalLM (mamba2 模型)
- mega — MegaForMaskedLM (MEGA 模型)
- megatron-bert — MegatronBertForPreTraining (Megatron-BERT 模型)
- mllama — MllamaForConditionalGeneration (Mllama 模型)
- mobilebert — MobileBertForPreTraining (MobileBERT 模型)
- mpnet — MPNetForMaskedLM (MPNet 模型)
- mpt — MptForCausalLM (MPT 模型)
- mra — MraForMaskedLM (MRA 模型)
- mvp — MvpForConditionalGeneration (MVP 模型)
- nezha — NezhaForPreTraining (Nezha 模型)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma 模型)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio 模型)
- retribert — RetriBertModel (RetriBERT 模型)
- roberta — RobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForPreTraining (RoCBert 模型)
- rwkv — RwkvForCausalLM (RWKV 模型)
- splinter — SplinterForPreTraining (Splinter 模型)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT 模型)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- t5 — T5ForConditionalGeneration (T5 模型)
- tapas — TapasForMaskedLM (TAPAS 模型)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL 模型)
- tvlt — TvltForPreTraining (TVLT 模型)
- unispeech — UniSpeechForPreTraining (UniSpeech 模型)
- unispeech-sat — UniSpeechSatForPreTraining (UniSpeechSat 模型)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava 模型)
- videomae — VideoMAEForPreTraining (VideoMAE 模型)
- vipllava — VipLlavaForConditionalGeneration (VipLlava 模型)
- visual_bert — VisualBertForPreTraining (VisualBERT 模型)
- vit_mae — ViTMAEForPreTraining (ViTMAE 模型)
- wav2vec2 — Wav2Vec2ForPreTraining (Wav2Vec2 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer 模型)
- xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- xlnet — XLNetLMHeadModel (XLNet 模型)
- xmod — XmodForMaskedLM (X-MOD 模型)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForPreTraining.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
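As the note above says, from_pretrained() leaves the model in evaluation mode. A minimal sketch of toggling between the two modes (standard PyTorch nn.Module API, model as in the example above):

>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> model.training  # from_pretrained() defaults to evaluation mode
False
>>> _ = model.train()  # re-enable dropout etc. before fine-tuning
>>> model.training
True
>>> _ = model.eval()  # switch back for inference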
TFAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: TFBartForConditionalGeneration (BART model)
- BertConfig configuration class: TFBertForPreTraining (BERT model)
- CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: TFCamembertForMaskedLM (CamemBERT model)
- DistilBertConfig configuration class: TFDistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: TFElectraForPreTraining (ELECTRA model)
- FlaubertConfig configuration class: TFFlaubertWithLMHeadModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForPreTraining (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
- IdeficsConfig configuration class: TFIdeficsForVisionText2Text (IDEFICS model)
- LayoutLMConfig configuration class: TFLayoutLMForMaskedLM (LayoutLM model)
- LxmertConfig configuration class: TFLxmertForPreTraining (LXMERT model)
- MPNetConfig configuration class: TFMPNetForMaskedLM (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForPreTraining (MobileBERT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- RobertaConfig configuration class: TFRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: TFT5ForConditionalGeneration (T5 model)
- TapasConfig configuration class: TFTapasForMaskedLM (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
- ViTMAEConfig configuration class: TFViTMAEForPreTraining (ViTMAE model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
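A minimal sketch of this path; the checkpoint name only supplies the configuration, and the resulting weights are freshly initialized:

>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForPreTraining.from_config(config)  # random weights, config only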
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForPreTraining (ALBERT 模型)
- bart — TFBartForConditionalGeneration (BART 模型)
- bert — TFBertForPreTraining (BERT 模型)
- camembert — TFCamembertForMaskedLM (CamemBERT 模型)
- ctrl — TFCTRLLMHeadModel (CTRL 模型)
- distilbert — TFDistilBertForMaskedLM (DistilBERT 模型)
- electra — TFElectraForPreTraining (ELECTRA 模型)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT 模型)
- funnel — TFFunnelForPreTraining (Funnel Transformer 模型)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- idefics — TFIdeficsForVisionText2Text (IDEFICS 模型)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM 模型)
- lxmert — TFLxmertForPreTraining (LXMERT 模型)
- mobilebert — TFMobileBertForPreTraining (MobileBERT 模型)
- mpnet — TFMPNetForMaskedLM (MPNet 模型)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- roberta — TFRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- t5 — TFT5ForConditionalGeneration (T5 模型)
- tapas — TFTapasForMaskedLM (TAPAS 模型)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL 模型)
- vit_mae — TFViTMAEForPreTraining (ViTMAE 模型)
- xlm — TFXLMWithLMHeadModel (XLM 模型)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlnet — TFXLNetLMHeadModel (XLNet 模型)
Example:
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
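Since revision accepts any git identifier (see the parameter list above), a download can be pinned to a fixed reference. A sketch pinning to the default branch; a tag or commit id works the same way:

>>> model = TFAutoModelForPreTraining.from_pretrained(
...     "google-bert/bert-base-cased", revision="main"  # branch name, tag, or commit id
... )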
FlaxAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
- BertConfig configuration class: FlaxBertForPreTraining (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForPreTraining (BigBird model)
- ElectraConfig configuration class: FlaxElectraForPreTraining (ELECTRA model)
- LongT5Config configuration class: FlaxLongT5ForConditionalGeneration (LongT5 model)
- MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
- MT5Config configuration class: FlaxMT5ForConditionalGeneration (MT5 model)
- RoFormerConfig configuration class: FlaxRoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: FlaxT5ForConditionalGeneration (T5 model)
- Wav2Vec2Config configuration class: FlaxWav2Vec2ForPreTraining (Wav2Vec2 model)
- WhisperConfig configuration class: FlaxWhisperForConditionalGeneration (Whisper model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForPreTraining (ALBERT 模型)
- bart — FlaxBartForConditionalGeneration (BART 模型)
- bert — FlaxBertForPreTraining (BERT 模型)
- big_bird — FlaxBigBirdForPreTraining (BigBird 模型)
- electra — FlaxElectraForPreTraining (ELECTRA 模型)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 模型)
- mbart — FlaxMBartForConditionalGeneration (mBART 模型)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 模型)
- roberta — FlaxRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForMaskedLM (RoFormer 模型)
- t5 — FlaxT5ForConditionalGeneration (T5 模型)
- wav2vec2 — FlaxWav2Vec2ForPreTraining (Wav2Vec2 模型)
- whisper — FlaxWhisperForConditionalGeneration (Whisper 模型)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
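The config parameter above notes that a model saved with save_pretrained() can be reloaded from its save directory; a minimal round-trip sketch (the local path is illustrative):

>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> model.save_pretrained("./my_model_directory/")
>>> reloaded = FlaxAutoModelForPreTraining.from_pretrained("./my_model_directory/")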
Natural Language Processing
The following auto classes are available for the following natural language processing tasks.
AutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: BartForCausalLM (BART model)
- BertConfig configuration class: BertLMHeadModel (BERT model)
- BertGenerationConfig configuration class: BertGenerationDecoder (Bert Generation model)
- BigBirdConfig configuration class: BigBirdForCausalLM (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusForCausalLM (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptForCausalLM (BioGpt model)
- BlenderbotConfig configuration class: BlenderbotForCausalLM (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallForCausalLM (BlenderbotSmall model)
- BloomConfig configuration class: BloomForCausalLM (BLOOM model)
- CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: CamembertForCausalLM (CamemBERT model)
- CodeGenConfig configuration class: CodeGenForCausalLM (CodeGen model)
- CohereConfig configuration class: CohereForCausalLM (Cohere model)
- CpmAntConfig configuration class: CpmAntForCausalLM (CPM-Ant model)
- Data2VecTextConfig configuration class: Data2VecTextForCausalLM (Data2VecText model)
- DbrxConfig configuration class: DbrxForCausalLM (DBRX model)
- ElectraConfig configuration class: ElectraForCausalLM (ELECTRA model)
- ErnieConfig configuration class: ErnieForCausalLM (ERNIE model)
- FalconConfig configuration class: FalconForCausalLM (Falcon model)
- FalconMambaConfig configuration class: FalconMambaForCausalLM (FalconMamba model)
- FuyuConfig configuration class: FuyuForCausalLM (Fuyu model)
- GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
- GPTJConfig configuration class: GPTJForCausalLM (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoForCausalLM (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXForCausalLM (GPT NeoX model)
- GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese model)
- Gemma2Config configuration class: Gemma2ForCausalLM (Gemma2 model)
- GemmaConfig configuration class: GemmaForCausalLM (Gemma model)
- GitConfig configuration class: GitForCausalLM (GIT model)
- GlmConfig configuration class: GlmForCausalLM (GLM model)
- GraniteConfig configuration class: GraniteForCausalLM (Granite model)
- GraniteMoeConfig configuration class: GraniteMoeForCausalLM (GraniteMoe model)
- JambaConfig configuration class: JambaForCausalLM (Jamba model)
- JetMoeConfig configuration class: JetMoeForCausalLM (JetMoe model)
- LlamaConfig configuration class: LlamaForCausalLM (LLaMA model)
- MBartConfig configuration class: MBartForCausalLM (mBART model)
- Mamba2Config configuration class: Mamba2ForCausalLM (mamba2 model)
- MambaConfig configuration class: MambaForCausalLM (Mamba model)
- MarianConfig configuration class: MarianForCausalLM (Marian model)
- MegaConfig configuration class: MegaForCausalLM (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForCausalLM (Megatron-BERT model)
- MistralConfig configuration class: MistralForCausalLM (Mistral model)
- MixtralConfig configuration class: MixtralForCausalLM (Mixtral model)
- MllamaConfig configuration class: MllamaForCausalLM (Mllama model)
- MoshiConfig configuration class: MoshiForCausalLM (Moshi model)
- MptConfig configuration class: MptForCausalLM (MPT model)
- MusicgenConfig configuration class: MusicgenForCausalLM (MusicGen model)
- MusicgenMelodyConfig configuration class: MusicgenMelodyForCausalLM (MusicGen Melody model)
- MvpConfig configuration class: MvpForCausalLM (MVP model)
- NemotronConfig configuration class: NemotronForCausalLM (Nemotron model)
- OPTConfig configuration class: OPTForCausalLM (OPT model)
- Olmo2Config configuration class: Olmo2ForCausalLM (OLMo2 model)
- OlmoConfig configuration class: OlmoForCausalLM (OLMo model)
- OlmoeConfig configuration class: OlmoeForCausalLM (OLMoE model)
- OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaForCausalLM (OpenLlama model)
- PLBartConfig configuration class: PLBartForCausalLM (PLBart model)
- PegasusConfig configuration class: PegasusForCausalLM (Pegasus model)
- PersimmonConfig configuration class: PersimmonForCausalLM (Persimmon model)
- Phi3Config configuration class: Phi3ForCausalLM (Phi3 model)
- PhiConfig configuration class: PhiForCausalLM (Phi model)
- PhimoeConfig configuration class: PhimoeForCausalLM (Phimoe model)
- ProphetNetConfig configuration class: ProphetNetForCausalLM (ProphetNet model)
- QDQBertConfig configuration class: QDQBertLMHeadModel (QDQBert model)
- Qwen2Config configuration class: Qwen2ForCausalLM (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeForCausalLM (Qwen2MoE model)
- RecurrentGemmaConfig configuration class: RecurrentGemmaForCausalLM (RecurrentGemma model)
- ReformerConfig configuration class: ReformerModelWithLMHead (Reformer model)
- RemBertConfig configuration class: RemBertForCausalLM (RemBERT model)
- RoCBertConfig configuration class: RoCBertForCausalLM (RoCBert model)
- RoFormerConfig configuration class: RoFormerForCausalLM (RoFormer model)
- RobertaConfig configuration class: RobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
- Speech2Text2Config configuration class: Speech2Text2ForCausalLM (Speech2Text2 model)
- StableLmConfig configuration class: StableLmForCausalLM (StableLm model)
- Starcoder2Config configuration class: Starcoder2ForCausalLM (Starcoder2 model)
- TrOCRConfig configuration class: TrOCRForCausalLM (TrOCR model)
- TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
- WhisperConfig configuration class: WhisperForCausalLM (Whisper model)
- XGLMConfig configuration class: XGLMForCausalLM (XGLM model)
- XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
- XLMProphetNetConfig configuration class: XLMProphetNetForCausalLM (XLM-ProphetNet model)
- XLMRobertaConfig configuration class: XLMRobertaForCausalLM (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForCausalLM (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
- XmodConfig configuration class: XmodForCausalLM (X-MOD model)
- ZambaConfig configuration class: ZambaForCausalLM (Zamba model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
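A short sketch of building a causal LM from a configuration alone, for instance to train from scratch; the checkpoint name is only used to fetch the config:

>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> config = AutoConfig.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_config(config)  # randomly initialized, no pretrained weights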
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — BartForCausalLM (BART 模型)
- bert — BertLMHeadModel (BERT 模型)
- bert-generation — BertGenerationDecoder (Bert 生成模型)
- big_bird — BigBirdForCausalLM (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusForCausalLM (BigBird-Pegasus 模型)
- biogpt — BioGptForCausalLM (BioGpt 模型)
- blenderbot — BlenderbotForCausalLM (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallForCausalLM (BlenderbotSmall 模型)
- bloom — BloomForCausalLM (BLOOM 模型)
- camembert — CamembertForCausalLM (CamemBERT 模型)
- code_llama — LlamaForCausalLM (CodeLlama 模型)
- codegen — CodeGenForCausalLM (CodeGen 模型)
- cohere — CohereForCausalLM (Cohere 模型)
- cpmant — CpmAntForCausalLM (CPM-Ant 模型)
- ctrl — CTRLLMHeadModel (CTRL 模型)
- data2vec-text — Data2VecTextForCausalLM (Data2VecText 模型)
- dbrx — DbrxForCausalLM (DBRX 模型)
- electra — ElectraForCausalLM (ELECTRA 模型)
- ernie — ErnieForCausalLM (ERNIE 模型)
- falcon — FalconForCausalLM (Falcon 模型)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba 模型)
- fuyu — FuyuForCausalLM (Fuyu 模型)
- gemma — GemmaForCausalLM (Gemma 模型)
- gemma2 — Gemma2ForCausalLM (Gemma2 模型)
- git — GitForCausalLM (GIT 模型)
- glm — GlmForCausalLM (GLM 模型)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode 模型)
- gpt_neo — GPTNeoForCausalLM (GPT Neo 模型)
- gpt_neox — GPTNeoXForCausalLM (GPT NeoX 模型)
- gpt_neox_japanese — GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese 模型)
- gptj — GPTJForCausalLM (GPT-J 模型)
- granite — GraniteForCausalLM (Granite 模型)
- granitemoe — GraniteMoeForCausalLM (GraniteMoe 模型)
- jamba — JambaForCausalLM (Jamba 模型)
- jetmoe — JetMoeForCausalLM (JetMoe 模型)
- llama — LlamaForCausalLM (LLaMA 模型)
- mamba — MambaForCausalLM (Mamba 模型)
- mamba2 — Mamba2ForCausalLM (mamba2 模型)
- marian — MarianForCausalLM (Marian 模型)
- mbart — MBartForCausalLM (mBART 模型)
- mega — MegaForCausalLM (MEGA 模型)
- megatron-bert — MegatronBertForCausalLM (Megatron-BERT 模型)
- mistral — MistralForCausalLM (Mistral 模型)
- mixtral — MixtralForCausalLM (Mixtral 模型)
- mllama — MllamaForCausalLM (Mllama 模型)
- moshi — MoshiForCausalLM (Moshi 模型)
- mpt — MptForCausalLM (MPT 模型)
- musicgen — MusicgenForCausalLM (MusicGen 模型)
- musicgen_melody — MusicgenMelodyForCausalLM (MusicGen Melody 模型)
- mvp — MvpForCausalLM (MVP 模型)
- nemotron — NemotronForCausalLM (Nemotron 模型)
- olmo — OlmoForCausalLM (OLMo 模型)
- olmo2 — Olmo2ForCausalLM (OLMo2 模型)
- olmoe — OlmoeForCausalLM (OLMoE 模型)
- open-llama — OpenLlamaForCausalLM (OpenLlama 模型)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- opt — OPTForCausalLM (OPT 模型)
- pegasus — PegasusForCausalLM (Pegasus 模型)
- persimmon — PersimmonForCausalLM (Persimmon 模型)
- phi — PhiForCausalLM (Phi 模型)
- phi3 — Phi3ForCausalLM (Phi3 模型)
- phimoe — PhimoeForCausalLM (Phimoe 模型)
- plbart — PLBartForCausalLM (PLBart 模型)
- prophetnet — ProphetNetForCausalLM (ProphetNet 模型)
- qdqbert — QDQBertLMHeadModel (QDQBert 模型)
- qwen2 — Qwen2ForCausalLM (Qwen2 模型)
- qwen2_moe — Qwen2MoeForCausalLM (Qwen2MoE 模型)
- recurrent_gemma — RecurrentGemmaForCausalLM (RecurrentGemma 模型)
- reformer — ReformerModelWithLMHead (Reformer 模型)
- rembert — RemBertForCausalLM (RemBERT 模型)
- roberta — RobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForCausalLM (RoCBert 模型)
- roformer — RoFormerForCausalLM (RoFormer 模型)
- rwkv — RwkvForCausalLM (RWKV 模型)
- speech_to_text_2 — Speech2Text2ForCausalLM (Speech2Text2 模型)
- stablelm — StableLmForCausalLM (StableLm 模型)
- starcoder2 — Starcoder2ForCausalLM (Starcoder2 模型)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL 模型)
- trocr — TrOCRForCausalLM (TrOCR 模型)
- whisper — WhisperForCausalLM (Whisper 模型)
- xglm — XGLMForCausalLM (XGLM 模型)
- xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-prophetnet — XLMProphetNetForCausalLM (XLM-ProphetNet 模型)
- xlm-roberta — XLMRobertaForCausalLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForCausalLM (XLM-RoBERTa-XL 模型)
- xlnet — XLNetLMHeadModel (XLNet 模型)
- xmod — XmodForCausalLM (X-MOD 模型)
- zamba — ZambaForCausalLM (Zamba 模型)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCausalLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
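The attn_implementation option documented for from_config() can also be passed to from_pretrained(), and configuration attributes such as output_attentions can be overridden in the same call, as described under kwargs above. A sketch, assuming torch>=2.1.1 so SDPA is available:

>>> model = AutoModelForCausalLM.from_pretrained(
...     "openai-community/gpt2", attn_implementation="sdpa", output_attentions=True
... )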
TFAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BertConfig configuration class: TFBertLMHeadModel (BERT model)
- CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: TFCamembertForCausalLM (CamemBERT model)
- GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJForCausalLM (GPT-J model)
- MistralConfig configuration class: TFMistralForCausalLM (Mistral model)
- OPTConfig configuration class: TFOPTForCausalLM (OPT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- RemBertConfig configuration class: TFRemBertForCausalLM (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForCausalLM (RoFormer model)
- RobertaConfig configuration class: TFRobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
- XGLMConfig configuration class: TFXGLMForCausalLM (XGLM model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForCausalLM (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bert — TFBertLMHeadModel (BERT 模型)
- camembert — TFCamembertForCausalLM (CamemBERT 模型)
- ctrl — TFCTRLLMHeadModel (CTRL 模型)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- gptj — TFGPTJForCausalLM (GPT-J 模型)
- mistral — TFMistralForCausalLM (Mistral 模型)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- opt — TFOPTForCausalLM (OPT 模型)
- rembert — TFRemBertForCausalLM (RemBERT 模型)
- roberta — TFRobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForCausalLM (RoFormer 模型)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL 模型)
- xglm — TFXGLMForCausalLM (XGLM 模型)
- xlm — TFXLMWithLMHeadModel (XLM 模型)
- xlm-roberta — TFXLMRobertaForCausalLM (XLM-RoBERTa 模型)
- xlnet — TFXLNetLMHeadModel (XLNet 模型)
Example:
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
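For checkpoints whose modeling code lives on the Hub (see trust_remote_code above), loading requires explicit opt-in; a sketch with a hypothetical repository name:

>>> model = TFAutoModelForCausalLM.from_pretrained(
...     "my-org/my-custom-model", trust_remote_code=True  # hypothetical repo; executes Hub code locally
... )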
FlaxAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: FlaxBartForCausalLM (BART model)
- BertConfig configuration class: FlaxBertForCausalLM (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForCausalLM (BigBird model)
- BloomConfig configuration class: FlaxBloomForCausalLM (BLOOM model)
- ElectraConfig configuration class: FlaxElectraForCausalLM (ELECTRA model)
- GPT2Config configuration class: FlaxGPT2LMHeadModel (OpenAI GPT-2 model)
- GPTJConfig configuration class: FlaxGPTJForCausalLM (GPT-J model)
- GPTNeoConfig configuration class: FlaxGPTNeoForCausalLM (GPT Neo model)
- GemmaConfig configuration class: FlaxGemmaForCausalLM (Gemma model)
- LlamaConfig configuration class: FlaxLlamaForCausalLM (LLaMA model)
- MistralConfig configuration class: FlaxMistralForCausalLM (Mistral model)
- OPTConfig configuration class: FlaxOPTForCausalLM (OPT model)
- RobertaConfig configuration class: FlaxRobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- XGLMConfig configuration class: FlaxXGLMForCausalLM (XGLM model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForCausalLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — FlaxBartForCausalLM (BART 模型)
- bert — FlaxBertForCausalLM (BERT 模型)
- big_bird — FlaxBigBirdForCausalLM (BigBird 模型)
- bloom — FlaxBloomForCausalLM (BLOOM 模型)
- electra — FlaxElectraForCausalLM (ELECTRA 模型)
- gemma — FlaxGemmaForCausalLM (Gemma 模型)
- gpt-sw3 — FlaxGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — FlaxGPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_neo — FlaxGPTNeoForCausalLM (GPT Neo 模型)
- gptj — FlaxGPTJForCausalLM (GPT-J 模型)
- llama — FlaxLlamaForCausalLM (LLaMA 模型)
- mistral — FlaxMistralForCausalLM (Mistral 模型)
- opt — FlaxOPTForCausalLM (OPT 模型)
- roberta — FlaxRobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- xglm — FlaxXGLMForCausalLM (XGLM 模型)
- xlm-roberta — FlaxXLMRobertaForCausalLM (XLM-RoBERTa 模型)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
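For offline use, the local_files_only flag described above restricts loading to files already present in the local cache or on disk; a sketch:

>>> model = FlaxAutoModelForCausalLM.from_pretrained(
...     "google-bert/bert-base-cased", local_files_only=True  # raises if nothing is cached locally
... )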
AutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForMaskedLM (ALBERT model)
- BartConfig configuration class: BartForConditionalGeneration (BART model)
- BertConfig configuration class: BertForMaskedLM (BERT model)
- BigBirdConfig configuration class: BigBirdForMaskedLM (BigBird model)
- CamembertConfig configuration class: CamembertForMaskedLM (CamemBERT model)
- ConvBertConfig configuration class: ConvBertForMaskedLM (ConvBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForMaskedLM (Data2VecText model)
- DebertaConfig configuration class: DebertaForMaskedLM (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForMaskedLM (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: ElectraForMaskedLM (ELECTRA model)
- ErnieConfig configuration class: ErnieForMaskedLM (ERNIE model)
- EsmConfig configuration class: EsmForMaskedLM (ESM model)
- FNetConfig configuration class: FNetForMaskedLM (FNet model)
- FlaubertConfig configuration class: FlaubertWithLMHeadModel (FlauBERT model)
- FunnelConfig configuration class: FunnelForMaskedLM (Funnel Transformer model)
- IBertConfig configuration class: IBertForMaskedLM (I-BERT model)
- LayoutLMConfig configuration class: LayoutLMForMaskedLM (LayoutLM model)
- LongformerConfig configuration class: LongformerForMaskedLM (Longformer model)
- LukeConfig configuration class: LukeForMaskedLM (LUKE model)
- MBartConfig configuration class: MBartForConditionalGeneration (mBART model)
- MPNetConfig configuration class: MPNetForMaskedLM (MPNet model)
- MegaConfig configuration class: MegaForMaskedLM (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForMaskedLM (Megatron-BERT model)
- MobileBertConfig configuration class: MobileBertForMaskedLM (MobileBERT model)
- MraConfig configuration class: MraForMaskedLM (MRA model)
- MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
- NezhaConfig configuration class: NezhaForMaskedLM (Nezha model)
- NystromformerConfig configuration class: NystromformerForMaskedLM (Nyströmformer model)
- PerceiverConfig configuration class: PerceiverForMaskedLM (Perceiver model)
- QDQBertConfig configuration class: QDQBertForMaskedLM (QDQBert model)
- ReformerConfig configuration class: ReformerForMaskedLM (Reformer model)
- RemBertConfig configuration class: RemBertForMaskedLM (RemBERT model)
- RoCBertConfig configuration class: RoCBertForMaskedLM (RoCBert model)
- RoFormerConfig configuration class: RoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: RobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- SqueezeBertConfig configuration class: SqueezeBertForMaskedLM (SqueezeBERT model)
- TapasConfig configuration class: TapasForMaskedLM (TAPAS model)
- Wav2Vec2Config configuration class: Wav2Vec2ForMaskedLM (Wav2Vec2 model)
- XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- XmodConfig configuration class: XmodForMaskedLM (X-MOD model)
- YosoConfig configuration class: YosoForMaskedLM (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model’s __init__() method.
- config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForMaskedLM (ALBERT 模型)
- bart — BartForConditionalGeneration (BART 模型)
- bert — BertForMaskedLM (BERT 模型)
- big_bird — BigBirdForMaskedLM (BigBird 模型)
- camembert — CamembertForMaskedLM (CamemBERT 模型)
- convbert — ConvBertForMaskedLM (ConvBERT 模型)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText 模型)
- deberta — DebertaForMaskedLM (DeBERTa 模型)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- distilbert — DistilBertForMaskedLM (DistilBERT 模型)
- electra — ElectraForMaskedLM (ELECTRA 模型)
- ernie — ErnieForMaskedLM (ERNIE 模型)
- esm — EsmForMaskedLM (ESM 模型)
- flaubert — FlaubertWithLMHeadModel (FlauBERT 模型)
- fnet — FNetForMaskedLM (FNet 模型)
- funnel — FunnelForMaskedLM (Funnel Transformer 模型)
- ibert — IBertForMaskedLM (I-BERT 模型)
- layoutlm — LayoutLMForMaskedLM (LayoutLM 模型)
- longformer — LongformerForMaskedLM (Longformer 模型)
- luke — LukeForMaskedLM (LUKE 模型)
- mbart — MBartForConditionalGeneration (mBART 模型)
- mega — MegaForMaskedLM (MEGA 模型)
- megatron-bert — MegatronBertForMaskedLM (Megatron-BERT 模型)
- mobilebert — MobileBertForMaskedLM (MobileBERT 模型)
- mpnet — MPNetForMaskedLM (MPNet 模型)
- mra — MraForMaskedLM (MRA 模型)
- mvp — MvpForConditionalGeneration (MVP 模型)
- nezha — NezhaForMaskedLM (Nezha 模型)
- nystromformer — NystromformerForMaskedLM (Nyströmformer 模型)
- perceiver — PerceiverForMaskedLM (Perceiver 模型)
- qdqbert — QDQBertForMaskedLM (QDQBert 模型)
- reformer — ReformerForMaskedLM (Reformer 模型)
- rembert — RemBertForMaskedLM (RemBERT 模型)
- roberta — RobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForMaskedLM (RoCBert 模型)
- roformer — RoFormerForMaskedLM (RoFormer 模型)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT 模型)
- tapas — TapasForMaskedLM (TAPAS 模型)
- wav2vec2 — Wav2Vec2ForMaskedLM (Wav2Vec2 模型)
- xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- xmod — XmodForMaskedLM (X-MOD 模型)
- yoso — YosoForMaskedLM (YOSO 模型)
The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
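The evaluation/training toggle described above can be checked directly; a minimal sketch (`model.training` is the standard PyTorch flag on `nn.Module`):
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> model.training  # loaded in evaluation mode by default
False
>>> _ = model.train()  # re-enable dropout etc. before fine-tuning
>>> model.training
True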
TFAutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForMaskedLM (ALBERT model)
- BertConfig configuration class: TFBertForMaskedLM (BERT model)
- CamembertConfig configuration class: TFCamembertForMaskedLM (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForMaskedLM (ConvBERT model)
- DebertaConfig configuration class: TFDebertaForMaskedLM (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2ForMaskedLM (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: TFElectraForMaskedLM (ELECTRA model)
- EsmConfig configuration class: TFEsmForMaskedLM (ESM model)
- FlaubertConfig configuration class: TFFlaubertWithLMHeadModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForMaskedLM (Funnel Transformer model)
- LayoutLMConfig configuration class: TFLayoutLMForMaskedLM (LayoutLM model)
- LongformerConfig configuration class: TFLongformerForMaskedLM (Longformer model)
- MPNetConfig configuration class: TFMPNetForMaskedLM (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForMaskedLM (MobileBERT model)
- RemBertConfig configuration class: TFRemBertForMaskedLM (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: TFRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- TapasConfig configuration class: TFTapasForMaskedLM (TAPAS model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
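A minimal from_config sketch, mirroring the pattern of the from_pretrained examples in this reference (the checkpoint name is only an example; no weights are loaded):
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForMaskedLM.from_config(config)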
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — TFAlbertForMaskedLM (ALBERT model)
- bert — TFBertForMaskedLM (BERT model)
- camembert — TFCamembertForMaskedLM (CamemBERT model)
- convbert — TFConvBertForMaskedLM (ConvBERT model)
- deberta — TFDebertaForMaskedLM (DeBERTa model)
- deberta-v2 — TFDebertaV2ForMaskedLM (DeBERTa-v2 model)
- distilbert — TFDistilBertForMaskedLM (DistilBERT model)
- electra — TFElectraForMaskedLM (ELECTRA model)
- esm — TFEsmForMaskedLM (ESM model)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT model)
- funnel — TFFunnelForMaskedLM (Funnel Transformer model)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM model)
- longformer — TFLongformerForMaskedLM (Longformer model)
- mobilebert — TFMobileBertForMaskedLM (MobileBERT model)
- mpnet — TFMPNetForMaskedLM (MPNet model)
- rembert — TFRemBertForMaskedLM (RemBERT model)
- roberta — TFRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForMaskedLM (RoFormer model)
- tapas — TFTapasForMaskedLM (TAPAS model)
- xlm — TFXLMWithLMHeadModel (XLM model)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForMaskedLM (ALBERT model)
- BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
- BertConfig configuration class: FlaxBertForMaskedLM (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForMaskedLM (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForMaskedLM (ELECTRA model)
- MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
- RoFormerConfig configuration class: FlaxRoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
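As a sketch, the same pattern works with the Flax class (again, only the configuration is used, not the weights):
>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModelForMaskedLM.from_config(config)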
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — FlaxAlbertForMaskedLM (ALBERT model)
- bart — FlaxBartForConditionalGeneration (BART model)
- bert — FlaxBertForMaskedLM (BERT model)
- big_bird — FlaxBigBirdForMaskedLM (BigBird model)
- distilbert — FlaxDistilBertForMaskedLM (DistilBERT model)
- electra — FlaxElectraForMaskedLM (ELECTRA model)
- mbart — FlaxMBartForConditionalGeneration (mBART model)
- roberta — FlaxRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForMaskedLM (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskGeneration
TFAutoModelForMaskGeneration
AutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: BartForConditionalGeneration (BART model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusForConditionalGeneration (BigBird-Pegasus model)
- BlenderbotConfig configuration class: BlenderbotForConditionalGeneration (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- EncoderDecoderConfig configuration class: EncoderDecoderModel (Encoder decoder model)
- FSMTConfig configuration class: FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- LEDConfig configuration class: LEDForConditionalGeneration (LED model)
- LongT5Config configuration class: LongT5ForConditionalGeneration (LongT5 model)
- M2M100Config configuration class: M2M100ForConditionalGeneration (M2M100 model)
- MBartConfig configuration class: MBartForConditionalGeneration (mBART model)
- MT5Config configuration class: MT5ForConditionalGeneration (MT5 model)
- MarianConfig configuration class: MarianMTModel (Marian model)
- MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
- NllbMoeConfig configuration class: NllbMoeForConditionalGeneration (NLLB-MOE model)
- PLBartConfig configuration class: PLBartForConditionalGeneration (PLBart model)
- PegasusConfig configuration class: PegasusForConditionalGeneration (Pegasus model)
- PegasusXConfig configuration class: PegasusXForConditionalGeneration (PEGASUS-X model)
- ProphetNetConfig configuration class: ProphetNetForConditionalGeneration (ProphetNet model)
- Qwen2AudioConfig configuration class: Qwen2AudioForConditionalGeneration (Qwen2Audio model)
- SeamlessM4TConfig configuration class: SeamlessM4TForTextToText (SeamlessM4T model)
- SeamlessM4Tv2Config configuration class: SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 model)
- SwitchTransformersConfig configuration class: SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- T5Config configuration class: T5ForConditionalGeneration (T5 model)
- UMT5Config configuration class: UMT5ForConditionalGeneration (UMT5 model)
- XLMProphetNetConfig configuration class: XLMProphetNetForConditionalGeneration (XLM-ProphetNet model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
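A minimal from_config sketch for the seq2seq case (the T5 checkpoint is just an example; the model is randomly initialized from the configuration):
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> config = AutoConfig.from_pretrained("google-t5/t5-base")
>>> model = AutoModelForSeq2SeqLM.from_config(config)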
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, *optional*) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, *optional*, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bart — BartForConditionalGeneration (BART model)
- bigbird_pegasus — BigBirdPegasusForConditionalGeneration (BigBird-Pegasus model)
- blenderbot — BlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — BlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — EncoderDecoderModel (Encoder decoder model)
- fsmt — FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- led — LEDForConditionalGeneration (LED model)
- longt5 — LongT5ForConditionalGeneration (LongT5 model)
- m2m_100 — M2M100ForConditionalGeneration (M2M100 model)
- marian — MarianMTModel (Marian model)
- mbart — MBartForConditionalGeneration (mBART model)
- mt5 — MT5ForConditionalGeneration (MT5 model)
- mvp — MvpForConditionalGeneration (MVP model)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE model)
- pegasus — PegasusForConditionalGeneration (Pegasus model)
- pegasus_x — PegasusXForConditionalGeneration (PEGASUS-X model)
- plbart — PLBartForConditionalGeneration (PLBart model)
- prophetnet — ProphetNetForConditionalGeneration (ProphetNet model)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio model)
- seamless_m4t — SeamlessM4TForTextToText (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 model)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- t5 — T5ForConditionalGeneration (T5 model)
- umt5 — UMT5ForConditionalGeneration (UMT5 model)
- xlm-prophetnet — XLMProphetNetForConditionalGeneration (XLM-ProphetNet model)
The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/t5_tf_model_config.json")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(
... "./tf_model/t5_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
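Once loaded with pretrained weights, the model can be used for generation; a minimal usage sketch, assuming the matching tokenizer and the standard T5 translation prefix:
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
>>> outputs = model.generate(**inputs, max_new_tokens=20)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))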
TFAutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: TFBartForConditionalGeneration (BART model)
- BlenderbotConfig configuration class: TFBlenderbotForConditionalGeneration (Blenderbot model)
- BlenderbotSmallConfig configuration class: TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- EncoderDecoderConfig configuration class: TFEncoderDecoderModel (Encoder decoder model)
- LEDConfig configuration class: TFLEDForConditionalGeneration (LED model)
- MBartConfig configuration class: TFMBartForConditionalGeneration (mBART model)
- MT5Config configuration class: TFMT5ForConditionalGeneration (MT5 model)
- MarianConfig configuration class: TFMarianMTModel (Marian model)
- PegasusConfig configuration class: TFPegasusForConditionalGeneration (Pegasus model)
- T5Config configuration class: TFT5ForConditionalGeneration (T5 model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
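A from_config sketch for the TensorFlow class (configuration only, no weights):
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> config = AutoConfig.from_pretrained("google-t5/t5-base")
>>> model = TFAutoModelForSeq2SeqLM.from_config(config)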
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bart — TFBartForConditionalGeneration (BART model)
- blenderbot — TFBlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — TFEncoderDecoderModel (Encoder decoder model)
- led — TFLEDForConditionalGeneration (LED model)
- marian — TFMarianMTModel (Marian model)
- mbart — TFMBartForConditionalGeneration (mBART model)
- mt5 — TFMT5ForConditionalGeneration (MT5 model)
- pegasus — TFPegasusForConditionalGeneration (Pegasus model)
- t5 — TFT5ForConditionalGeneration (T5 model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
- BlenderbotConfig configuration class: FlaxBlenderbotForConditionalGeneration (Blenderbot model)
- BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- EncoderDecoderConfig configuration class: FlaxEncoderDecoderModel (Encoder decoder model)
- LongT5Config configuration class: FlaxLongT5ForConditionalGeneration (LongT5 model)
- MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
- MT5Config configuration class: FlaxMT5ForConditionalGeneration (MT5 model)
- MarianConfig configuration class: FlaxMarianMTModel (Marian model)
- PegasusConfig configuration class: FlaxPegasusForConditionalGeneration (Pegasus model)
- T5Config configuration class: FlaxT5ForConditionalGeneration (T5 model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
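And the Flax counterpart, as a sketch (no weights are loaded here either):
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> config = AutoConfig.from_pretrained("google-t5/t5-base")
>>> model = FlaxAutoModelForSeq2SeqLM.from_config(config)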
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bart — FlaxBartForConditionalGeneration (BART model)
- blenderbot — FlaxBlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — FlaxEncoderDecoderModel (Encoder decoder model)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 model)
- marian — FlaxMarianMTModel (Marian model)
- mbart — FlaxMBartForConditionalGeneration (mBART model)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 model)
- pegasus — FlaxPegasusForConditionalGeneration (Pegasus model)
- t5 — FlaxT5ForConditionalGeneration (T5 model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForSequenceClassification (ALBERT model)
- BartConfig configuration class: BartForSequenceClassification (BART model)
- BertConfig configuration class: BertForSequenceClassification (BERT model)
- BigBirdConfig configuration class: BigBirdForSequenceClassification (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusForSequenceClassification (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptForSequenceClassification (BioGpt model)
- BloomConfig configuration class: BloomForSequenceClassification (BLOOM model)
- CTRLConfig configuration class: CTRLForSequenceClassification (CTRL model)
- CamembertConfig configuration class: CamembertForSequenceClassification (CamemBERT model)
- CanineConfig configuration class: CanineForSequenceClassification (CANINE model)
- ConvBertConfig configuration class: ConvBertForSequenceClassification (ConvBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForSequenceClassification (Data2VecText model)
- DebertaConfig configuration class: DebertaForSequenceClassification (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForSequenceClassification (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForSequenceClassification (DistilBERT model)
- ElectraConfig configuration class: ElectraForSequenceClassification (ELECTRA model)
- ErnieConfig configuration class: ErnieForSequenceClassification (ERNIE model)
- ErnieMConfig configuration class: ErnieMForSequenceClassification (ErnieM model)
- EsmConfig configuration class: EsmForSequenceClassification (ESM model)
- FNetConfig configuration class: FNetForSequenceClassification (FNet model)
- FalconConfig configuration class: FalconForSequenceClassification (Falcon model)
- FlaubertConfig configuration class: FlaubertForSequenceClassification (FlauBERT model)
- FunnelConfig configuration class: FunnelForSequenceClassification (Funnel Transformer model)
- GPT2Config configuration class: GPT2ForSequenceClassification (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForSequenceClassification (GPTBigCode model)
- GPTJConfig configuration class: GPTJForSequenceClassification (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoForSequenceClassification (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXForSequenceClassification (GPT NeoX model)
- Gemma2Config configuration class: Gemma2ForSequenceClassification (Gemma2 model)
- GemmaConfig configuration class: GemmaForSequenceClassification (Gemma model)
- GlmConfig configuration class: GlmForSequenceClassification (GLM model)
- IBertConfig configuration class: IBertForSequenceClassification (I-BERT model)
- JambaConfig configuration class: JambaForSequenceClassification (Jamba model)
- JetMoeConfig configuration class: JetMoeForSequenceClassification (JetMoe model)
- LEDConfig configuration class: LEDForSequenceClassification (LED model)
- LayoutLMConfig configuration class: LayoutLMForSequenceClassification (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2ForSequenceClassification (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- LiltConfig configuration class: LiltForSequenceClassification (LiLT model)
- LlamaConfig configuration class: LlamaForSequenceClassification (LLaMA model)
- LongformerConfig configuration class: LongformerForSequenceClassification (Longformer model)
- LukeConfig configuration class: LukeForSequenceClassification (LUKE model)
- MBartConfig configuration class: MBartForSequenceClassification (mBART model)
- MPNetConfig configuration class: MPNetForSequenceClassification (MPNet model)
- MT5Config configuration class: MT5ForSequenceClassification (MT5 model)
- MarkupLMConfig configuration class: MarkupLMForSequenceClassification (MarkupLM model)
- MegaConfig configuration class: MegaForSequenceClassification (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForSequenceClassification (Megatron-BERT model)
- MistralConfig configuration class: MistralForSequenceClassification (Mistral model)
- MixtralConfig configuration class: MixtralForSequenceClassification (Mixtral model)
- MobileBertConfig configuration class: MobileBertForSequenceClassification (MobileBERT model)
- MptConfig configuration class: MptForSequenceClassification (MPT model)
- MraConfig configuration class: MraForSequenceClassification (MRA model)
- MvpConfig configuration class: MvpForSequenceClassification (MVP model)
- NemotronConfig configuration class: NemotronForSequenceClassification (Nemotron model)
- NezhaConfig configuration class: NezhaForSequenceClassification (Nezha model)
- NystromformerConfig configuration class: NystromformerForSequenceClassification (Nyströmformer model)
- OPTConfig configuration class: OPTForSequenceClassification (OPT model)
- OpenAIGPTConfig configuration class: OpenAIGPTForSequenceClassification (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaForSequenceClassification (OpenLlama model)
- PLBartConfig configuration class: PLBartForSequenceClassification (PLBart model)
- PerceiverConfig configuration class: PerceiverForSequenceClassification (Perceiver model)
- PersimmonConfig configuration class: PersimmonForSequenceClassification (Persimmon model)
- Phi3Config configuration class: Phi3ForSequenceClassification (Phi3 model)
- PhiConfig configuration class: PhiForSequenceClassification (Phi model)
- PhimoeConfig configuration class: PhimoeForSequenceClassification (Phimoe model)
- QDQBertConfig configuration class: QDQBertForSequenceClassification (QDQBert model)
- Qwen2Config configuration class: Qwen2ForSequenceClassification (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeForSequenceClassification (Qwen2MoE model)
- ReformerConfig configuration class: ReformerForSequenceClassification (Reformer model)
- RemBertConfig configuration class: RemBertForSequenceClassification (RemBERT model)
- RoCBertConfig configuration class: RoCBertForSequenceClassification (RoCBert model)
- RoFormerConfig configuration class: RoFormerForSequenceClassification (RoFormer model)
- RobertaConfig configuration class: RobertaForSequenceClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- SqueezeBertConfig configuration class: SqueezeBertForSequenceClassification (SqueezeBERT model)
- StableLmConfig configuration class: StableLmForSequenceClassification (StableLm model)
- Starcoder2Config configuration class: Starcoder2ForSequenceClassification (Starcoder2 model)
- T5Config configuration class: T5ForSequenceClassification (T5 model)
- TapasConfig configuration class: TapasForSequenceClassification (TAPAS model)
- TransfoXLConfig configuration class: TransfoXLForSequenceClassification (Transformer-XL model)
- UMT5Config configuration class: UMT5ForSequenceClassification (UMT5 model)
- XLMConfig configuration class: XLMForSequenceClassification (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForSequenceClassification (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetForSequenceClassification (XLNet model)
- XmodConfig configuration class: XmodForSequenceClassification (X-MOD model)
- YosoConfig configuration class: YosoForSequenceClassification (YOSO model)
- ZambaConfig configuration class: ZambaForSequenceClassification (Zamba model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
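A minimal from_config sketch (the classification head is randomly initialized from the configuration; no weights are loaded):
>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModelForSequenceClassification.from_config(config)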
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, *optional*) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, *optional*, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — AlbertForSequenceClassification (ALBERT model)
- bart — BartForSequenceClassification (BART model)
- bert — BertForSequenceClassification (BERT model)
- big_bird — BigBirdForSequenceClassification (BigBird model)
- bigbird_pegasus — BigBirdPegasusForSequenceClassification (BigBird-Pegasus model)
- biogpt — BioGptForSequenceClassification (BioGpt model)
- bloom — BloomForSequenceClassification (BLOOM model)
- camembert — CamembertForSequenceClassification (CamemBERT model)
- canine — CanineForSequenceClassification (CANINE model)
- code_llama — LlamaForSequenceClassification (CodeLlama model)
- convbert — ConvBertForSequenceClassification (ConvBERT model)
- ctrl — CTRLForSequenceClassification (CTRL model)
- data2vec-text — Data2VecTextForSequenceClassification (Data2VecText model)
- deberta — DebertaForSequenceClassification (DeBERTa model)
- deberta-v2 — DebertaV2ForSequenceClassification (DeBERTa-v2 model)
- distilbert — DistilBertForSequenceClassification (DistilBERT model)
- electra — ElectraForSequenceClassification (ELECTRA model)
- ernie — ErnieForSequenceClassification (ERNIE model)
- ernie_m — ErnieMForSequenceClassification (ErnieM model)
- esm — EsmForSequenceClassification (ESM model)
- falcon — FalconForSequenceClassification (Falcon model)
- flaubert — FlaubertForSequenceClassification (FlauBERT model)
- fnet — FNetForSequenceClassification (FNet model)
- funnel — FunnelForSequenceClassification (Funnel Transformer model)
- gemma — GemmaForSequenceClassification (Gemma model)
- gemma2 — Gemma2ForSequenceClassification (Gemma2 model)
- glm — GlmForSequenceClassification (GLM model)
- gpt-sw3 — GPT2ForSequenceClassification (GPT-Sw3 model)
- gpt2 — GPT2ForSequenceClassification (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForSequenceClassification (GPTBigCode model)
- gpt_neo — GPTNeoForSequenceClassification (GPT Neo model)
- gpt_neox — GPTNeoXForSequenceClassification (GPT NeoX model)
- gptj — GPTJForSequenceClassification (GPT-J model)
- ibert — IBertForSequenceClassification (I-BERT model)
- jamba — JambaForSequenceClassification (Jamba model)
- jetmoe — JetMoeForSequenceClassification (JetMoe model)
- layoutlm — LayoutLMForSequenceClassification (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForSequenceClassification (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- led — LEDForSequenceClassification (LED model)
- lilt — LiltForSequenceClassification (LiLT model)
- llama — LlamaForSequenceClassification (LLaMA model)
- longformer — LongformerForSequenceClassification (Longformer model)
- luke — LukeForSequenceClassification (LUKE model)
- markuplm — MarkupLMForSequenceClassification (MarkupLM model)
- mbart — MBartForSequenceClassification (mBART model)
- mega — MegaForSequenceClassification (MEGA model)
- megatron-bert — MegatronBertForSequenceClassification (Megatron-BERT model)
- mistral — MistralForSequenceClassification (Mistral model)
- mixtral — MixtralForSequenceClassification (Mixtral model)
- mobilebert — MobileBertForSequenceClassification (MobileBERT model)
- mpnet — MPNetForSequenceClassification (MPNet model)
- mpt — MptForSequenceClassification (MPT model)
- mra — MraForSequenceClassification (MRA model)
- mt5 — MT5ForSequenceClassification (MT5 model)
- mvp — MvpForSequenceClassification (MVP model)
- nemotron — NemotronForSequenceClassification (Nemotron model)
- nezha — NezhaForSequenceClassification (Nezha model)
- nystromformer — NystromformerForSequenceClassification (Nyströmformer model)
- open-llama — OpenLlamaForSequenceClassification (OpenLlama model)
- openai-gpt — OpenAIGPTForSequenceClassification (OpenAI GPT model)
- opt — OPTForSequenceClassification (OPT model)
- perceiver — PerceiverForSequenceClassification (Perceiver model)
- persimmon — PersimmonForSequenceClassification (Persimmon model)
- phi — PhiForSequenceClassification (Phi model)
- phi3 — Phi3ForSequenceClassification (Phi3 model)
- phimoe — PhimoeForSequenceClassification (Phimoe model)
- plbart — PLBartForSequenceClassification (PLBart model)
- qdqbert — QDQBertForSequenceClassification (QDQBert model)
- qwen2 — Qwen2ForSequenceClassification (Qwen2 model)
- qwen2_moe — Qwen2MoeForSequenceClassification (Qwen2MoE model)
- reformer — ReformerForSequenceClassification (Reformer model)
- rembert — RemBertForSequenceClassification (RemBERT model)
- roberta — RobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForSequenceClassification (RoCBert model)
- roformer — RoFormerForSequenceClassification (RoFormer model)
- squeezebert — SqueezeBertForSequenceClassification (SqueezeBERT model)
- stablelm — StableLmForSequenceClassification (StableLm model)
- starcoder2 — Starcoder2ForSequenceClassification (Starcoder2 model)
- t5 — T5ForSequenceClassification (T5 model)
- tapas — TapasForSequenceClassification (TAPAS model)
- transfo-xl — TransfoXLForSequenceClassification (Transformer-XL model)
- umt5 — UMT5ForSequenceClassification (UMT5 model)
- xlm — XLMForSequenceClassification (XLM model)
- xlm-roberta — XLMRobertaForSequenceClassification (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL model)
- xlnet — XLNetForSequenceClassification (XLNet model)
- xmod — XmodForSequenceClassification (X-MOD model)
- yoso — YosoForSequenceClassification (YOSO model)
- zamba — ZambaForSequenceClassification (Zamba model)
The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSequenceClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
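For completeness, a minimal inference sketch with the loaded model (illustrative only: a generic checkpoint such as this one gets a newly initialized classification head, so the predicted class id is meaningless until the model is fine-tuned):
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> inputs = tokenizer("Hello, world!", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_class_id = int(logits.argmax(dim=-1))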
TFAutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForSequenceClassification (ALBERT model)
- BartConfig configuration class: TFBartForSequenceClassification (BART model)
- BertConfig configuration class: TFBertForSequenceClassification (BERT model)
- CTRLConfig configuration class: TFCTRLForSequenceClassification (CTRL model)
- CamembertConfig configuration class: TFCamembertForSequenceClassification (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForSequenceClassification (ConvBERT model)
- DebertaConfig configuration class: TFDebertaForSequenceClassification (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2ForSequenceClassification (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForSequenceClassification (DistilBERT model)
- ElectraConfig configuration class: TFElectraForSequenceClassification (ELECTRA model)
- EsmConfig configuration class: TFEsmForSequenceClassification (ESM model)
- FlaubertConfig configuration class: TFFlaubertForSequenceClassification (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForSequenceClassification (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2ForSequenceClassification (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJForSequenceClassification (GPT-J model)
- LayoutLMConfig configuration class: TFLayoutLMForSequenceClassification (LayoutLM model)
- LayoutLMv3Config configuration class: TFLayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerForSequenceClassification (Longformer model)
- MPNetConfig configuration class: TFMPNetForSequenceClassification (MPNet model)
- MistralConfig configuration class: TFMistralForSequenceClassification (Mistral model)
- MobileBertConfig configuration class: TFMobileBertForSequenceClassification (MobileBERT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTForSequenceClassification (OpenAI GPT model)
- RemBertConfig configuration class: TFRemBertForSequenceClassification (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForSequenceClassification (RoFormer model)
- RobertaConfig configuration class: TFRobertaForSequenceClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- TapasConfig configuration class: TFTapasForSequenceClassification (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLForSequenceClassification (Transformer-XL model)
- XLMConfig configuration class: TFXLMForSequenceClassification (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetForSequenceClassification (XLNet model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
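The same from_config pattern applies to the TensorFlow class; a sketch:
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForSequenceClassification.from_config(config)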
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether to allow custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it is loaded) and to initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — TFAlbertForSequenceClassification (ALBERT model)
- bart — TFBartForSequenceClassification (BART model)
- bert — TFBertForSequenceClassification (BERT model)
- camembert — TFCamembertForSequenceClassification (CamemBERT model)
- convbert — TFConvBertForSequenceClassification (ConvBERT model)
- ctrl — TFCTRLForSequenceClassification (CTRL model)
- deberta — TFDebertaForSequenceClassification (DeBERTa model)
- deberta-v2 — TFDebertaV2ForSequenceClassification (DeBERTa-v2 model)
- distilbert — TFDistilBertForSequenceClassification (DistilBERT model)
- electra — TFElectraForSequenceClassification (ELECTRA model)
- esm — TFEsmForSequenceClassification (ESM model)
- flaubert — TFFlaubertForSequenceClassification (FlauBERT model)
- funnel — TFFunnelForSequenceClassification (Funnel Transformer model)
- gpt-sw3 — TFGPT2ForSequenceClassification (GPT-Sw3 model)
- gpt2 — TFGPT2ForSequenceClassification (OpenAI GPT-2 model)
- gptj — TFGPTJForSequenceClassification (GPT-J model)
- layoutlm — TFLayoutLMForSequenceClassification (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- longformer — TFLongformerForSequenceClassification (Longformer model)
- mistral — TFMistralForSequenceClassification (Mistral model)
- mobilebert — TFMobileBertForSequenceClassification (MobileBERT model)
- mpnet — TFMPNetForSequenceClassification (MPNet model)
- openai-gpt — TFOpenAIGPTForSequenceClassification (OpenAI GPT model)
- rembert — TFRemBertForSequenceClassification (RemBERT model)
- roberta — TFRobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForSequenceClassification (RoFormer model)
- tapas — TFTapasForSequenceClassification (TAPAS model)
- transfo-xl — TFTransfoXLForSequenceClassification (Transformer-XL model)
- xlm — TFXLMForSequenceClassification (XLM model)
- xlm-roberta — TFXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- xlnet — TFXLNetForSequenceClassification (XLNet model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForSequenceClassification (ALBERT model)
- BartConfig configuration class: FlaxBartForSequenceClassification (BART model)
- BertConfig configuration class: FlaxBertForSequenceClassification (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForSequenceClassification (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForSequenceClassification (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForSequenceClassification (ELECTRA model)
- MBartConfig configuration class: FlaxMBartForSequenceClassification (mBART model)
- RoFormerConfig configuration class: FlaxRoFormerForSequenceClassification (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForSequenceClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
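For example, a minimal sketch of pairing from_config() with a config loaded separately (the resulting FlaxBertForSequenceClassification has random, not pretrained, weights):

>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModelForSequenceClassification.from_config(config)  # architecture only, no pretrained weights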
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — FlaxAlbertForSequenceClassification (ALBERT model)
- bart — FlaxBartForSequenceClassification (BART model)
- bert — FlaxBertForSequenceClassification (BERT model)
- big_bird — FlaxBigBirdForSequenceClassification (BigBird model)
- distilbert — FlaxDistilBertForSequenceClassification (DistilBERT model)
- electra — FlaxElectraForSequenceClassification (ELECTRA model)
- mbart — FlaxMBartForSequenceClassification (mBART model)
- roberta — FlaxRobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForSequenceClassification (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForMultipleChoice (ALBERT model)
- BertConfig configuration class: BertForMultipleChoice (BERT model)
- BigBirdConfig configuration class: BigBirdForMultipleChoice (BigBird model)
- CamembertConfig configuration class: CamembertForMultipleChoice (CamemBERT model)
- CanineConfig configuration class: CanineForMultipleChoice (CANINE model)
- ConvBertConfig configuration class: ConvBertForMultipleChoice (ConvBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForMultipleChoice (Data2VecText model)
- DebertaV2Config configuration class: DebertaV2ForMultipleChoice (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForMultipleChoice (DistilBERT model)
- ElectraConfig configuration class: ElectraForMultipleChoice (ELECTRA model)
- ErnieConfig configuration class: ErnieForMultipleChoice (ERNIE model)
- ErnieMConfig configuration class: ErnieMForMultipleChoice (ErnieM model)
- FNetConfig configuration class: FNetForMultipleChoice (FNet model)
- FlaubertConfig configuration class: FlaubertForMultipleChoice (FlauBERT model)
- FunnelConfig configuration class: FunnelForMultipleChoice (Funnel Transformer model)
- IBertConfig configuration class: IBertForMultipleChoice (I-BERT model)
- LongformerConfig configuration class: LongformerForMultipleChoice (Longformer model)
- LukeConfig configuration class: LukeForMultipleChoice (LUKE model)
- MPNetConfig configuration class: MPNetForMultipleChoice (MPNet model)
- MegaConfig configuration class: MegaForMultipleChoice (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForMultipleChoice (Megatron-BERT model)
- MobileBertConfig configuration class: MobileBertForMultipleChoice (MobileBERT model)
- MraConfig configuration class: MraForMultipleChoice (MRA model)
- NezhaConfig configuration class: NezhaForMultipleChoice (Nezha model)
- NystromformerConfig configuration class: NystromformerForMultipleChoice (Nyströmformer model)
- QDQBertConfig configuration class: QDQBertForMultipleChoice (QDQBert model)
- RemBertConfig configuration class: RemBertForMultipleChoice (RemBERT model)
- RoCBertConfig configuration class: RoCBertForMultipleChoice (RoCBert model)
- RoFormerConfig configuration class: RoFormerForMultipleChoice (RoFormer model)
- RobertaConfig configuration class: RobertaForMultipleChoice (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- SqueezeBertConfig configuration class: SqueezeBertForMultipleChoice (SqueezeBERT model)
- XLMConfig configuration class: XLMForMultipleChoice (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForMultipleChoice (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetForMultipleChoice (XLNet model)
- XmodConfig configuration class: XmodForMultipleChoice (X-MOD model)
- YosoConfig configuration class: YosoForMultipleChoice (YOSO model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
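A minimal sketch of the difference: from_config() below selects BertForMultipleChoice from the config class but leaves the weights randomly initialized, whereas from_pretrained() would also download and load the checkpoint.

>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModelForMultipleChoice.from_config(config)  # random weights; use from_pretrained() for the checkpoint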
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — AlbertForMultipleChoice (ALBERT model)
- bert — BertForMultipleChoice (BERT model)
- big_bird — BigBirdForMultipleChoice (BigBird model)
- camembert — CamembertForMultipleChoice (CamemBERT model)
- canine — CanineForMultipleChoice (CANINE model)
- convbert — ConvBertForMultipleChoice (ConvBERT model)
- data2vec-text — Data2VecTextForMultipleChoice (Data2VecText model)
- deberta-v2 — DebertaV2ForMultipleChoice (DeBERTa-v2 model)
- distilbert — DistilBertForMultipleChoice (DistilBERT model)
- electra — ElectraForMultipleChoice (ELECTRA model)
- ernie — ErnieForMultipleChoice (ERNIE model)
- ernie_m — ErnieMForMultipleChoice (ErnieM model)
- flaubert — FlaubertForMultipleChoice (FlauBERT model)
- fnet — FNetForMultipleChoice (FNet model)
- funnel — FunnelForMultipleChoice (Funnel Transformer model)
- ibert — IBertForMultipleChoice (I-BERT model)
- longformer — LongformerForMultipleChoice (Longformer model)
- luke — LukeForMultipleChoice (LUKE model)
- mega — MegaForMultipleChoice (MEGA model)
- megatron-bert — MegatronBertForMultipleChoice (Megatron-BERT model)
- mobilebert — MobileBertForMultipleChoice (MobileBERT model)
- mpnet — MPNetForMultipleChoice (MPNet model)
- mra — MraForMultipleChoice (MRA model)
- nezha — NezhaForMultipleChoice (Nezha model)
- nystromformer — NystromformerForMultipleChoice (Nyströmformer model)
- qdqbert — QDQBertForMultipleChoice (QDQBert model)
- rembert — RemBertForMultipleChoice (RemBERT model)
- roberta — RobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForMultipleChoice (RoCBert model)
- roformer — RoFormerForMultipleChoice (RoFormer model)
- squeezebert — SqueezeBertForMultipleChoice (SqueezeBERT model)
- xlm — XLMForMultipleChoice (XLM model)
- xlm-roberta — XLMRobertaForMultipleChoice (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL model)
- xlnet — XLNetForMultipleChoice (XLNet model)
- xmod — XmodForMultipleChoice (X-MOD model)
- yoso — YosoForMultipleChoice (YOSO model)
The model is set in evaluation mode by default using `model.eval()` (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
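A short sketch of toggling between the two modes (this is standard PyTorch `nn.Module` behavior):

>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> model.training  # from_pretrained() returns the model in evaluation mode
False
>>> model = model.train()  # re-enable dropout etc. for fine-tuning
>>> model.training
True
>>> model = model.eval()  # back to deterministic inference behavior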
Example:
>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMultipleChoice.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForMultipleChoice (ALBERT model)
- BertConfig configuration class: TFBertForMultipleChoice (BERT model)
- CamembertConfig configuration class: TFCamembertForMultipleChoice (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForMultipleChoice (ConvBERT model)
- DebertaV2Config configuration class: TFDebertaV2ForMultipleChoice (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForMultipleChoice (DistilBERT model)
- ElectraConfig configuration class: TFElectraForMultipleChoice (ELECTRA model)
- FlaubertConfig configuration class: TFFlaubertForMultipleChoice (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForMultipleChoice (Funnel Transformer model)
- LongformerConfig configuration class: TFLongformerForMultipleChoice (Longformer model)
- MPNetConfig configuration class: TFMPNetForMultipleChoice (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForMultipleChoice (MobileBERT model)
- RemBertConfig configuration class: TFRemBertForMultipleChoice (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForMultipleChoice (RoFormer model)
- RobertaConfig configuration class: TFRobertaForMultipleChoice (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- XLMConfig configuration class: TFXLMForMultipleChoice (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForMultipleChoice (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetForMultipleChoice (XLNet model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
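A minimal sketch, mirroring the PyTorch version above but for TensorFlow:

>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForMultipleChoice.from_config(config)  # TFBertForMultipleChoice with random weights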
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — TFAlbertForMultipleChoice (ALBERT model)
- bert — TFBertForMultipleChoice (BERT model)
- camembert — TFCamembertForMultipleChoice (CamemBERT model)
- convbert — TFConvBertForMultipleChoice (ConvBERT model)
- deberta-v2 — TFDebertaV2ForMultipleChoice (DeBERTa-v2 model)
- distilbert — TFDistilBertForMultipleChoice (DistilBERT model)
- electra — TFElectraForMultipleChoice (ELECTRA model)
- flaubert — TFFlaubertForMultipleChoice (FlauBERT model)
- funnel — TFFunnelForMultipleChoice (Funnel Transformer model)
- longformer — TFLongformerForMultipleChoice (Longformer model)
- mobilebert — TFMobileBertForMultipleChoice (MobileBERT model)
- mpnet — TFMPNetForMultipleChoice (MPNet model)
- rembert — TFRemBertForMultipleChoice (RemBERT model)
- roberta — TFRobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForMultipleChoice (RoFormer model)
- xlm — TFXLMForMultipleChoice (XLM model)
- xlm-roberta — TFXLMRobertaForMultipleChoice (XLM-RoBERTa model)
- xlnet — TFXLNetForMultipleChoice (XLNet model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForMultipleChoice (ALBERT model)
- BertConfig configuration class: FlaxBertForMultipleChoice (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForMultipleChoice (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForMultipleChoice (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForMultipleChoice (ELECTRA model)
- RoFormerConfig configuration class: FlaxRoFormerForMultipleChoice (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForMultipleChoice (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
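A minimal sketch for the Flax variant:

>>> from transformers import AutoConfig, FlaxAutoModelForMultipleChoice
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModelForMultipleChoice.from_config(config)  # FlaxBertForMultipleChoice, untrained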
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- albert — FlaxAlbertForMultipleChoice (ALBERT model)
- bert — FlaxBertForMultipleChoice (BERT model)
- big_bird — FlaxBigBirdForMultipleChoice (BigBird model)
- distilbert — FlaxDistilBertForMultipleChoice (DistilBERT model)
- electra — FlaxElectraForMultipleChoice (ELECTRA model)
- roberta — FlaxRobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForMultipleChoice (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BertConfig configuration class: BertForNextSentencePrediction (BERT model)
- ErnieConfig configuration class: ErnieForNextSentencePrediction (ERNIE model)
- FNetConfig configuration class: FNetForNextSentencePrediction (FNet model)
- MegatronBertConfig configuration class: MegatronBertForNextSentencePrediction (Megatron-BERT model)
- MobileBertConfig configuration class: MobileBertForNextSentencePrediction (MobileBERT model)
- NezhaConfig configuration class: NezhaForNextSentencePrediction (Nezha model)
- QDQBertConfig configuration class: QDQBertForNextSentencePrediction (QDQBert model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
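A minimal sketch (the config class, here BertConfig, determines that BertForNextSentencePrediction is instantiated):

>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = AutoModelForNextSentencePrediction.from_config(config)  # untrained; load weights with from_pretrained()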
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bert — BertForNextSentencePrediction (BERT model)
- ernie — ErnieForNextSentencePrediction (ERNIE model)
- fnet — FNetForNextSentencePrediction (FNet model)
- megatron-bert — MegatronBertForNextSentencePrediction (Megatron-BERT model)
- mobilebert — MobileBertForNextSentencePrediction (MobileBERT model)
- nezha — NezhaForNextSentencePrediction (Nezha model)
- qdqbert — QDQBertForNextSentencePrediction (QDQBert model)
The model is set in evaluation mode by default using `model.eval()` (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Example:
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForNextSentencePrediction.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters

- config (PretrainedConfig) —
  The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: TFBertForNextSentencePrediction (BERT model)
  - MobileBertConfig configuration class: TFMobileBertForNextSentencePrediction (MobileBERT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
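A minimal sketch for TensorFlow:

>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForNextSentencePrediction.from_config(config)  # TFBertForNextSentencePrediction, random weights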
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bert — TFBertForNextSentencePrediction (BERT model)
- mobilebert — TFMobileBertForNextSentencePrediction (MobileBERT model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters

- config (PretrainedConfig) —
  The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: FlaxBertForNextSentencePrediction (BERT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
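A minimal sketch for Flax:

>>> from transformers import AutoConfig, FlaxAutoModelForNextSentencePrediction
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModelForNextSentencePrediction.from_config(config)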
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- bert — FlaxBertForNextSentencePrediction (BERT model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForTokenClassification (ALBERT model)
- BertConfig configuration class: BertForTokenClassification (BERT model)
- BigBirdConfig configuration class: BigBirdForTokenClassification (BigBird model)
- BioGptConfig configuration class: BioGptForTokenClassification (BioGpt model)
- BloomConfig configuration class: BloomForTokenClassification (BLOOM model)
- BrosConfig configuration class: BrosForTokenClassification (BROS model)
- CamembertConfig configuration class: CamembertForTokenClassification (CamemBERT model)
- CanineConfig configuration class: CanineForTokenClassification (CANINE model)
- ConvBertConfig configuration class: ConvBertForTokenClassification (ConvBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForTokenClassification (Data2VecText model)
- DebertaConfig configuration class: DebertaForTokenClassification (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForTokenClassification (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForTokenClassification (DistilBERT model)
- ElectraConfig configuration class: ElectraForTokenClassification (ELECTRA model)
- ErnieConfig configuration class: ErnieForTokenClassification (ERNIE model)
- ErnieMConfig configuration class: ErnieMForTokenClassification (ErnieM model)
- EsmConfig configuration class: EsmForTokenClassification (ESM model)
- FNetConfig configuration class: FNetForTokenClassification (FNet model)
- FalconConfig configuration class: FalconForTokenClassification (Falcon model)
- FlaubertConfig configuration class: FlaubertForTokenClassification (FlauBERT model)
- FunnelConfig configuration class: FunnelForTokenClassification (Funnel Transformer model)
- GPT2Config configuration class: GPT2ForTokenClassification (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForTokenClassification (GPTBigCode model)
- GPTNeoConfig configuration class: GPTNeoForTokenClassification (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXForTokenClassification (GPT NeoX model)
- Gemma2Config configuration class: Gemma2ForTokenClassification (Gemma2 model)
- GemmaConfig configuration class: GemmaForTokenClassification (Gemma model)
- GlmConfig configuration class: GlmForTokenClassification (GLM model)
- IBertConfig configuration class: IBertForTokenClassification (I-BERT model)
- LayoutLMConfig configuration class: LayoutLMForTokenClassification (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2ForTokenClassification (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3ForTokenClassification (LayoutLMv3 model)
- LiltConfig configuration class: LiltForTokenClassification (LiLT model)
- LlamaConfig configuration class: LlamaForTokenClassification (LLaMA model)
- LongformerConfig configuration class: LongformerForTokenClassification (Longformer model)
- LukeConfig configuration class: LukeForTokenClassification (LUKE model)
- MPNetConfig configuration class: MPNetForTokenClassification (MPNet model)
- MT5Config configuration class: MT5ForTokenClassification (MT5 model)
- MarkupLMConfig configuration class: MarkupLMForTokenClassification (MarkupLM model)
- MegaConfig configuration class: MegaForTokenClassification (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForTokenClassification (Megatron-BERT model)
- MistralConfig configuration class: MistralForTokenClassification (Mistral model)
- MixtralConfig configuration class: MixtralForTokenClassification (Mixtral model)
- MobileBertConfig configuration class: MobileBertForTokenClassification (MobileBERT model)
- MptConfig configuration class: MptForTokenClassification (MPT model)
- MraConfig configuration class: MraForTokenClassification (MRA model)
- NemotronConfig configuration class: NemotronForTokenClassification (Nemotron model)
- NezhaConfig configuration class: NezhaForTokenClassification (Nezha model)
- NystromformerConfig configuration class: NystromformerForTokenClassification (Nyströmformer model)
- PersimmonConfig configuration class: PersimmonForTokenClassification (Persimmon model)
- Phi3Config configuration class: Phi3ForTokenClassification (Phi3 model)
- PhiConfig configuration class: PhiForTokenClassification (Phi model)
- QDQBertConfig configuration class: QDQBertForTokenClassification (QDQBert model)
- Qwen2Config configuration class: Qwen2ForTokenClassification (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeForTokenClassification (Qwen2MoE model)
- RemBertConfig configuration class: RemBertForTokenClassification (RemBERT model)
- RoCBertConfig configuration class: RoCBertForTokenClassification (RoCBert model)
- RoFormerConfig configuration class: RoFormerForTokenClassification (RoFormer model)
- RobertaConfig configuration class: RobertaForTokenClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- SqueezeBertConfig configuration class: SqueezeBertForTokenClassification (SqueezeBERT model)
- StableLmConfig configuration class: StableLmForTokenClassification (StableLm model)
- Starcoder2Config configuration class: Starcoder2ForTokenClassification (Starcoder2 model)
- T5Config configuration class: T5ForTokenClassification (T5 model)
- UMT5Config configuration class: UMT5ForTokenClassification (UMT5 model)
- XLMConfig configuration class: XLMForTokenClassification (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForTokenClassification (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetForTokenClassification (XLNet model)
- XmodConfig configuration class: XmodForTokenClassification (X-MOD model)
- YosoConfig configuration class: YosoForTokenClassification (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
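A minimal sketch of this path (the checkpoint name and num_labels value are illustrative): from_config builds the architecture from the configuration alone, so the weights are freshly initialized.
>>> from transformers import AutoConfig, AutoModelForTokenClassification
>>> # Downloads only the config file, not the weights; num_labels=9 is illustrative
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-cased", num_labels=9)
>>> model = AutoModelForTokenClassification.from_config(config)  # randomly initialized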
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForTokenClassification (ALBERT model)
- bert — BertForTokenClassification (BERT model)
- big_bird — BigBirdForTokenClassification (BigBird model)
- biogpt — BioGptForTokenClassification (BioGpt model)
- bloom — BloomForTokenClassification (BLOOM model)
- bros — BrosForTokenClassification (BROS model)
- camembert — CamembertForTokenClassification (CamemBERT model)
- canine — CanineForTokenClassification (CANINE model)
- convbert — ConvBertForTokenClassification (ConvBERT model)
- data2vec-text — Data2VecTextForTokenClassification (Data2VecText model)
- deberta — DebertaForTokenClassification (DeBERTa model)
- deberta-v2 — DebertaV2ForTokenClassification (DeBERTa-v2 model)
- distilbert — DistilBertForTokenClassification (DistilBERT model)
- electra — ElectraForTokenClassification (ELECTRA model)
- ernie — ErnieForTokenClassification (ERNIE model)
- ernie_m — ErnieMForTokenClassification (ErnieM model)
- esm — EsmForTokenClassification (ESM model)
- falcon — FalconForTokenClassification (Falcon model)
- flaubert — FlaubertForTokenClassification (FlauBERT model)
- fnet — FNetForTokenClassification (FNet model)
- funnel — FunnelForTokenClassification (Funnel Transformer model)
- gemma — GemmaForTokenClassification (Gemma model)
- gemma2 — Gemma2ForTokenClassification (Gemma2 model)
- glm — GlmForTokenClassification (GLM model)
- gpt-sw3 — GPT2ForTokenClassification (GPT-Sw3 model)
- gpt2 — GPT2ForTokenClassification (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForTokenClassification (GPTBigCode model)
- gpt_neo — GPTNeoForTokenClassification (GPT Neo model)
- gpt_neox — GPTNeoXForTokenClassification (GPT NeoX model)
- ibert — IBertForTokenClassification (I-BERT model)
- layoutlm — LayoutLMForTokenClassification (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForTokenClassification (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForTokenClassification (LayoutLMv3 model)
- lilt — LiltForTokenClassification (LiLT model)
- llama — LlamaForTokenClassification (LLaMA model)
- longformer — LongformerForTokenClassification (Longformer model)
- luke — LukeForTokenClassification (LUKE model)
- markuplm — MarkupLMForTokenClassification (MarkupLM model)
- mega — MegaForTokenClassification (MEGA model)
- megatron-bert — MegatronBertForTokenClassification (Megatron-BERT model)
- mistral — MistralForTokenClassification (Mistral model)
- mixtral — MixtralForTokenClassification (Mixtral model)
- mobilebert — MobileBertForTokenClassification (MobileBERT model)
- mpnet — MPNetForTokenClassification (MPNet model)
- mpt — MptForTokenClassification (MPT model)
- mra — MraForTokenClassification (MRA model)
- mt5 — MT5ForTokenClassification (MT5 model)
- nemotron — NemotronForTokenClassification (Nemotron model)
- nezha — NezhaForTokenClassification (Nezha model)
- nystromformer — NystromformerForTokenClassification (Nyströmformer model)
- persimmon — PersimmonForTokenClassification (Persimmon model)
- phi — PhiForTokenClassification (Phi model)
- phi3 — Phi3ForTokenClassification (Phi3 model)
- qdqbert — QDQBertForTokenClassification (QDQBert model)
- qwen2 — Qwen2ForTokenClassification (Qwen2 model)
- qwen2_moe — Qwen2MoeForTokenClassification (Qwen2MoE model)
- rembert — RemBertForTokenClassification (RemBERT model)
- roberta — RobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForTokenClassification (RoCBert model)
- roformer — RoFormerForTokenClassification (RoFormer model)
- squeezebert — SqueezeBertForTokenClassification (SqueezeBERT model)
- stablelm — StableLmForTokenClassification (StableLm model)
- starcoder2 — Starcoder2ForTokenClassification (Starcoder2 model)
- t5 — T5ForTokenClassification (T5 model)
- umt5 — UMT5ForTokenClassification (UMT5 model)
- xlm — XLMForTokenClassification (XLM model)
- xlm-roberta — XLMRobertaForTokenClassification (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL model)
- xlnet — XLNetForTokenClassification (XLNet model)
- xmod — XmodForTokenClassification (X-MOD model)
- yoso — YosoForTokenClassification (YOSO model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForTokenClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
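Building on the example above, a short inference sketch; the NER checkpoint and sentence are illustrative, not part of the reference text:
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForTokenClassification
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
>>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
>>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> # Map each token's highest-scoring class id to its label name
>>> [model.config.id2label[int(i)] for i in logits.argmax(dim=-1)[0]]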
TFAutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForTokenClassification (ALBERT model)
- BertConfig configuration class: TFBertForTokenClassification (BERT model)
- CamembertConfig configuration class: TFCamembertForTokenClassification (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForTokenClassification (ConvBERT model)
- DebertaConfig configuration class: TFDebertaForTokenClassification (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2ForTokenClassification (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForTokenClassification (DistilBERT model)
- ElectraConfig configuration class: TFElectraForTokenClassification (ELECTRA model)
- EsmConfig configuration class: TFEsmForTokenClassification (ESM model)
- FlaubertConfig configuration class: TFFlaubertForTokenClassification (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForTokenClassification (Funnel Transformer model)
- LayoutLMConfig configuration class: TFLayoutLMForTokenClassification (LayoutLM model)
- LayoutLMv3Config configuration class: TFLayoutLMv3ForTokenClassification (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerForTokenClassification (Longformer model)
- MPNetConfig configuration class: TFMPNetForTokenClassification (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForTokenClassification (MobileBERT model)
- RemBertConfig configuration class: TFRemBertForTokenClassification (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForTokenClassification (RoFormer model)
- RobertaConfig configuration class: TFRobertaForTokenClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- XLMConfig configuration class: TFXLMForTokenClassification (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForTokenClassification (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetForTokenClassification (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForTokenClassification (ALBERT model)
- bert — TFBertForTokenClassification (BERT model)
- camembert — TFCamembertForTokenClassification (CamemBERT model)
- convbert — TFConvBertForTokenClassification (ConvBERT model)
- deberta — TFDebertaForTokenClassification (DeBERTa model)
- deberta-v2 — TFDebertaV2ForTokenClassification (DeBERTa-v2 model)
- distilbert — TFDistilBertForTokenClassification (DistilBERT model)
- electra — TFElectraForTokenClassification (ELECTRA model)
- esm — TFEsmForTokenClassification (ESM model)
- flaubert — TFFlaubertForTokenClassification (FlauBERT model)
- funnel — TFFunnelForTokenClassification (Funnel Transformer model)
- layoutlm — TFLayoutLMForTokenClassification (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForTokenClassification (LayoutLMv3 model)
- longformer — TFLongformerForTokenClassification (Longformer model)
- mobilebert — TFMobileBertForTokenClassification (MobileBERT model)
- mpnet — TFMPNetForTokenClassification (MPNet model)
- rembert — TFRemBertForTokenClassification (RemBERT model)
- roberta — TFRobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForTokenClassification (RoFormer model)
- xlm — TFXLMForTokenClassification (XLM model)
- xlm-roberta — TFXLMRobertaForTokenClassification (XLM-RoBERTa model)
- xlnet — TFXLNetForTokenClassification (XLNet model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
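The loaded TF model can be called directly on tokenized inputs; a minimal sketch (the sentence is illustrative, and a head that was not fine-tuned for token classification is randomly initialized):
>>> from transformers import AutoTokenizer, TFAutoModelForTokenClassification
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> inputs = tokenizer("Hello world", return_tensors="tf")
>>> logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)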
FlaxAutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForTokenClassification (ALBERT model)
- BertConfig configuration class: FlaxBertForTokenClassification (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForTokenClassification (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForTokenClassification (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForTokenClassification (ELECTRA model)
- RoFormerConfig configuration class: FlaxRoFormerForTokenClassification (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForTokenClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForTokenClassification (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForTokenClassification (ALBERT model)
- bert — FlaxBertForTokenClassification (BERT model)
- big_bird — FlaxBigBirdForTokenClassification (BigBird model)
- distilbert — FlaxDistilBertForTokenClassification (DistilBERT model)
- electra — FlaxElectraForTokenClassification (ELECTRA model)
- roberta — FlaxRobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForTokenClassification (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForTokenClassification (XLM-RoBERTa model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
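Flax models accept NumPy inputs; a minimal inference sketch (the sentence is illustrative, with the same caveat about a randomly initialized task head):
>>> from transformers import AutoTokenizer, FlaxAutoModelForTokenClassification
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> inputs = tokenizer("Hello world", return_tensors="np")
>>> logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)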
AutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForQuestionAnswering (ALBERT model)
- BartConfig configuration class: BartForQuestionAnswering (BART model)
- BertConfig configuration class: BertForQuestionAnswering (BERT model)
- BigBirdConfig configuration class: BigBirdForQuestionAnswering (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusForQuestionAnswering (BigBird-Pegasus model)
- BloomConfig configuration class: BloomForQuestionAnswering (BLOOM model)
- CamembertConfig configuration class: CamembertForQuestionAnswering (CamemBERT model)
- CanineConfig configuration class: CanineForQuestionAnswering (CANINE model)
- ConvBertConfig configuration class: ConvBertForQuestionAnswering (ConvBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForQuestionAnswering (Data2VecText model)
- DebertaConfig configuration class: DebertaForQuestionAnswering (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForQuestionAnswering (DistilBERT model)
- ElectraConfig configuration class: ElectraForQuestionAnswering (ELECTRA model)
- ErnieConfig configuration class: ErnieForQuestionAnswering (ERNIE model)
- ErnieMConfig configuration class: ErnieMForQuestionAnswering (ErnieM model)
- FNetConfig configuration class: FNetForQuestionAnswering (FNet model)
- FalconConfig configuration class: FalconForQuestionAnswering (Falcon model)
- FlaubertConfig configuration class: FlaubertForQuestionAnsweringSimple (FlauBERT model)
- FunnelConfig configuration class: FunnelForQuestionAnswering (Funnel Transformer model)
- GPT2Config configuration class: GPT2ForQuestionAnswering (OpenAI GPT-2 model)
- GPTJConfig configuration class: GPTJForQuestionAnswering (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoForQuestionAnswering (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXForQuestionAnswering (GPT NeoX model)
- IBertConfig configuration class: IBertForQuestionAnswering (I-BERT model)
- LEDConfig configuration class: LEDForQuestionAnswering (LED model)
- LayoutLMv2Config configuration class: LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- LiltConfig configuration class: LiltForQuestionAnswering (LiLT model)
- LlamaConfig configuration class: LlamaForQuestionAnswering (LLaMA model)
- LongformerConfig configuration class: LongformerForQuestionAnswering (Longformer model)
- LukeConfig configuration class: LukeForQuestionAnswering (LUKE model)
- LxmertConfig configuration class: LxmertForQuestionAnswering (LXMERT model)
- MBartConfig configuration class: MBartForQuestionAnswering (mBART model)
- MPNetConfig configuration class: MPNetForQuestionAnswering (MPNet model)
- MT5Config configuration class: MT5ForQuestionAnswering (MT5 model)
- MarkupLMConfig configuration class: MarkupLMForQuestionAnswering (MarkupLM model)
- MegaConfig configuration class: MegaForQuestionAnswering (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForQuestionAnswering (Megatron-BERT model)
- MistralConfig configuration class: MistralForQuestionAnswering (Mistral model)
- MixtralConfig configuration class: MixtralForQuestionAnswering (Mixtral model)
- MobileBertConfig configuration class: MobileBertForQuestionAnswering (MobileBERT model)
- MptConfig configuration class: MptForQuestionAnswering (MPT model)
- MraConfig configuration class: MraForQuestionAnswering (MRA model)
- MvpConfig configuration class: MvpForQuestionAnswering (MVP model)
- NemotronConfig configuration class: NemotronForQuestionAnswering (Nemotron model)
- NezhaConfig configuration class: NezhaForQuestionAnswering (Nezha model)
- NystromformerConfig configuration class: NystromformerForQuestionAnswering (Nyströmformer model)
- OPTConfig configuration class: OPTForQuestionAnswering (OPT model)
- QDQBertConfig configuration class: QDQBertForQuestionAnswering (QDQBert model)
- Qwen2Config configuration class: Qwen2ForQuestionAnswering (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeForQuestionAnswering (Qwen2MoE model)
- ReformerConfig configuration class: ReformerForQuestionAnswering (Reformer model)
- RemBertConfig configuration class: RemBertForQuestionAnswering (RemBERT model)
- RoCBertConfig configuration class: RoCBertForQuestionAnswering (RoCBert model)
- RoFormerConfig configuration class: RoFormerForQuestionAnswering (RoFormer model)
- RobertaConfig configuration class: RobertaForQuestionAnswering (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- SplinterConfig configuration class: SplinterForQuestionAnswering (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertForQuestionAnswering (SqueezeBERT model)
- T5Config configuration class: T5ForQuestionAnswering (T5 model)
- UMT5Config configuration class: UMT5ForQuestionAnswering (UMT5 model)
- XLMConfig configuration class: XLMForQuestionAnsweringSimple (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForQuestionAnswering (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetForQuestionAnsweringSimple (XLNet model)
- XmodConfig configuration class: XmodForQuestionAnswering (X-MOD model)
- YosoConfig configuration class: YosoForQuestionAnswering (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForQuestionAnswering (ALBERT model)
- bart — BartForQuestionAnswering (BART model)
- bert — BertForQuestionAnswering (BERT model)
- big_bird — BigBirdForQuestionAnswering (BigBird model)
- bigbird_pegasus — BigBirdPegasusForQuestionAnswering (BigBird-Pegasus model)
- bloom — BloomForQuestionAnswering (BLOOM model)
- camembert — CamembertForQuestionAnswering (CamemBERT model)
- canine — CanineForQuestionAnswering (CANINE model)
- convbert — ConvBertForQuestionAnswering (ConvBERT model)
- data2vec-text — Data2VecTextForQuestionAnswering (Data2VecText model)
- deberta — DebertaForQuestionAnswering (DeBERTa model)
- deberta-v2 — DebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- distilbert — DistilBertForQuestionAnswering (DistilBERT model)
- electra — ElectraForQuestionAnswering (ELECTRA model)
- ernie — ErnieForQuestionAnswering (ERNIE model)
- ernie_m — ErnieMForQuestionAnswering (ErnieM model)
- falcon — FalconForQuestionAnswering (Falcon model)
- flaubert — FlaubertForQuestionAnsweringSimple (FlauBERT model)
- fnet — FNetForQuestionAnswering (FNet model)
- funnel — FunnelForQuestionAnswering (Funnel Transformer model)
- gpt2 — GPT2ForQuestionAnswering (OpenAI GPT-2 model)
- gpt_neo — GPTNeoForQuestionAnswering (GPT Neo model)
- gpt_neox — GPTNeoXForQuestionAnswering (GPT NeoX model)
- gptj — GPTJForQuestionAnswering (GPT-J model)
- ibert — IBertForQuestionAnswering (I-BERT model)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- led — LEDForQuestionAnswering (LED model)
- lilt — LiltForQuestionAnswering (LiLT model)
- llama — LlamaForQuestionAnswering (LLaMA model)
- longformer — LongformerForQuestionAnswering (Longformer model)
- luke — LukeForQuestionAnswering (LUKE model)
- lxmert — LxmertForQuestionAnswering (LXMERT model)
- markuplm — MarkupLMForQuestionAnswering (MarkupLM model)
- mbart — MBartForQuestionAnswering (mBART model)
- mega — MegaForQuestionAnswering (MEGA model)
- megatron-bert — MegatronBertForQuestionAnswering (Megatron-BERT model)
- mistral — MistralForQuestionAnswering (Mistral model)
- mixtral — MixtralForQuestionAnswering (Mixtral model)
- mobilebert — MobileBertForQuestionAnswering (MobileBERT model)
- mpnet — MPNetForQuestionAnswering (MPNet model)
- mpt — MptForQuestionAnswering (MPT model)
- mra — MraForQuestionAnswering (MRA model)
- mt5 — MT5ForQuestionAnswering (MT5 model)
- mvp — MvpForQuestionAnswering (MVP model)
- nemotron — NemotronForQuestionAnswering (Nemotron model)
- nezha — NezhaForQuestionAnswering (Nezha model)
- nystromformer — NystromformerForQuestionAnswering (Nyströmformer model)
- opt — OPTForQuestionAnswering (OPT model)
- qdqbert — QDQBertForQuestionAnswering (QDQBert model)
- qwen2 — Qwen2ForQuestionAnswering (Qwen2 model)
- qwen2_moe — Qwen2MoeForQuestionAnswering (Qwen2MoE model)
- reformer — ReformerForQuestionAnswering (Reformer model)
- rembert — RemBertForQuestionAnswering (RemBERT model)
- roberta — RobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForQuestionAnswering (RoCBert model)
- roformer — RoFormerForQuestionAnswering (RoFormer model)
- splinter — SplinterForQuestionAnswering (Splinter model)
- squeezebert — SqueezeBertForQuestionAnswering (SqueezeBERT model)
- t5 — T5ForQuestionAnswering (T5 model)
- umt5 — UMT5ForQuestionAnswering (UMT5 model)
- xlm — XLMForQuestionAnsweringSimple (XLM model)
- xlm-roberta — XLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForQuestionAnswering (XLM-RoBERTa-XL model)
- xlnet — XLNetForQuestionAnsweringSimple (XLNet model)
- xmod — XmodForQuestionAnswering (X-MOD model)
- yoso — YosoForQuestionAnswering (YOSO model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForQuestionAnswering.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
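As a follow-up, a short extractive question answering sketch; the SQuAD2 checkpoint, question, and context are illustrative:
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForQuestionAnswering
>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
>>> model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
>>> question, context = "Where do I live?", "My name is Tim and I live in Sweden."
>>> inputs = tokenizer(question, context, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> start = int(outputs.start_logits.argmax())
>>> end = int(outputs.end_logits.argmax())
>>> tokenizer.decode(inputs["input_ids"][0][start : end + 1])  # the predicted answer span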
TFAutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForQuestionAnswering (ALBERT model)
- BertConfig configuration class: TFBertForQuestionAnswering (BERT model)
- CamembertConfig configuration class: TFCamembertForQuestionAnswering (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForQuestionAnswering (ConvBERT model)
- DebertaConfig configuration class: TFDebertaForQuestionAnswering (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForQuestionAnswering (DistilBERT model)
- ElectraConfig configuration class: TFElectraForQuestionAnswering (ELECTRA model)
- FlaubertConfig configuration class: TFFlaubertForQuestionAnsweringSimple (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForQuestionAnswering (Funnel Transformer model)
- GPTJConfig configuration class: TFGPTJForQuestionAnswering (GPT-J model)
- LayoutLMv3Config configuration class: TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerForQuestionAnswering (Longformer model)
- MPNetConfig configuration class: TFMPNetForQuestionAnswering (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForQuestionAnswering (MobileBERT model)
- RemBertConfig configuration class: TFRemBertForQuestionAnswering (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForQuestionAnswering (RoFormer model)
- RobertaConfig configuration class: TFRobertaForQuestionAnswering (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- XLMConfig configuration class: TFXLMForQuestionAnsweringSimple (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetForQuestionAnsweringSimple (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForQuestionAnswering (ALBERT model)
- bert — TFBertForQuestionAnswering (BERT model)
- camembert — TFCamembertForQuestionAnswering (CamemBERT model)
- convbert — TFConvBertForQuestionAnswering (ConvBERT model)
- deberta — TFDebertaForQuestionAnswering (DeBERTa model)
- deberta-v2 — TFDebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- distilbert — TFDistilBertForQuestionAnswering (DistilBERT model)
- electra — TFElectraForQuestionAnswering (ELECTRA model)
- flaubert — TFFlaubertForQuestionAnsweringSimple (FlauBERT model)
- funnel — TFFunnelForQuestionAnswering (Funnel Transformer model)
- gptj — TFGPTJForQuestionAnswering (GPT-J model)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- longformer — TFLongformerForQuestionAnswering (Longformer model)
- mobilebert — TFMobileBertForQuestionAnswering (MobileBERT model)
- mpnet — TFMPNetForQuestionAnswering (MPNet model)
- rembert — TFRemBertForQuestionAnswering (RemBERT model)
- roberta — TFRobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForQuestionAnswering (RoFormer model)
- xlm — TFXLMForQuestionAnsweringSimple (XLM model)
- xlm-roberta — TFXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- xlnet — TFXLNetForQuestionAnsweringSimple (XLNet model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForQuestionAnswering (ALBERT model)
- BartConfig configuration class: FlaxBartForQuestionAnswering (BART model)
- BertConfig configuration class: FlaxBertForQuestionAnswering (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForQuestionAnswering (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForQuestionAnswering (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForQuestionAnswering (ELECTRA model)
- MBartConfig configuration class: FlaxMBartForQuestionAnswering (mBART model)
- RoFormerConfig configuration class: FlaxRoFormerForQuestionAnswering (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForQuestionAnswering (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model using the provided conversion scripts and loading the converted model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForQuestionAnswering (ALBERT model)
- bart — FlaxBartForQuestionAnswering (BART model)
- bert — FlaxBertForQuestionAnswering (BERT model)
- big_bird — FlaxBigBirdForQuestionAnswering (BigBird model)
- distilbert — FlaxDistilBertForQuestionAnswering (DistilBERT model)
- electra — FlaxElectraForQuestionAnswering (ELECTRA model)
- mbart — FlaxMBartForQuestionAnswering (mBART model)
- roberta — FlaxRobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForQuestionAnswering (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
Example:
>>> from transformers import AutoConfig, FlaxAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTextEncoding
TFAutoModelForTextEncoding
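The reference text for these two classes follows the same pattern as the classes above. As a rough sketch of what they do (assuming the t5 mapping; the checkpoint is illustrative), AutoModelForTextEncoding loads only the text encoder of an encoder-decoder checkpoint:
>>> from transformers import AutoTokenizer, AutoModelForTextEncoding
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> encoder = AutoModelForTextEncoding.from_pretrained("google-t5/t5-small")  # e.g. a T5EncoderModel
>>> inputs = tokenizer("Encode me", return_tensors="pt")
>>> hidden_states = encoder(**inputs).last_hidden_state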
Computer vision
The following auto classes are available for the following computer vision tasks.
AutoModelForDepthEstimation
This is a generic model class that will be instantiated as one of the model classes of the library (with a depth estimation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- DPTConfig configuration class: DPTForDepthEstimation (DPT model)
- DepthAnythingConfig configuration class: DepthAnythingForDepthEstimation (Depth Anything model)
- GLPNConfig configuration class: GLPNForDepthEstimation (GLPN model)
- ZoeDepthConfig configuration class: ZoeDepthForDepthEstimation (ZoeDepth model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a depth estimation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use, by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiates one of the model classes of the library (with a depth estimation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- depth_anything — DepthAnythingForDepthEstimation (Depth Anything model)
- dpt — DPTForDepthEstimation (DPT model)
- glpn — GLPNForDepthEstimation (GLPN model)
- zoedepth — ZoeDepthForDepthEstimation (ZoeDepth model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForDepthEstimation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large")
>>> # Update configuration during loading
>>> model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/dpt_tf_model_config.json")
>>> model = AutoModelForDepthEstimation.from_pretrained(
...     "./tf_model/dpt_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
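Loading is only half the story; the sketch below shows one way to actually run depth estimation with the loaded model. It is a minimal illustration, assuming the `Intel/dpt-large` checkpoint used above and the standard AutoImageProcessor API, rather than the canonical recipe:
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
>>> # Prepare an example image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # The processor resizes and normalizes the image the way the checkpoint expects
>>> processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
>>> model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large")
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> depth = outputs.predicted_depth  # one relative depth value per pixel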
AutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BeitConfig configuration class: BeitForImageClassification (BEiT model)
- BitConfig configuration class: BitForImageClassification (BiT model)
- CLIPConfig configuration class: CLIPForImageClassification (CLIP model)
- ConvNextConfig configuration class: ConvNextForImageClassification (ConvNeXT model)
- ConvNextV2Config configuration class: ConvNextV2ForImageClassification (ConvNeXTV2 model)
- CvtConfig configuration class: CvtForImageClassification (CvT model)
- Data2VecVisionConfig configuration class: Data2VecVisionForImageClassification (Data2VecVision model)
- DeiTConfig configuration class: DeiTForImageClassification or DeiTForImageClassificationWithTeacher (DeiT model)
- DinatConfig configuration class: DinatForImageClassification (DiNAT model)
- Dinov2Config configuration class: Dinov2ForImageClassification (DINOv2 model)
- EfficientFormerConfig configuration class: EfficientFormerForImageClassification or EfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- EfficientNetConfig configuration class: EfficientNetForImageClassification (EfficientNet model)
- FocalNetConfig configuration class: FocalNetForImageClassification (FocalNet model)
- HieraConfig configuration class: HieraForImageClassification (Hiera model)
- IJepaConfig configuration class: IJepaForImageClassification (I-JEPA model)
- ImageGPTConfig configuration class: ImageGPTForImageClassification (ImageGPT model)
- LevitConfig configuration class: LevitForImageClassification or LevitForImageClassificationWithTeacher (LeViT model)
- MobileNetV1Config configuration class: MobileNetV1ForImageClassification (MobileNetV1 model)
- MobileNetV2Config configuration class: MobileNetV2ForImageClassification (MobileNetV2 model)
- MobileViTConfig configuration class: MobileViTForImageClassification (MobileViT model)
- MobileViTV2Config configuration class: MobileViTV2ForImageClassification (MobileViTV2 model)
- NatConfig configuration class: NatForImageClassification (NAT model)
- PerceiverConfig configuration class: PerceiverForImageClassificationLearned or PerceiverForImageClassificationFourier or PerceiverForImageClassificationConvProcessing (Perceiver model)
- PoolFormerConfig configuration class: PoolFormerForImageClassification (PoolFormer model)
- PvtConfig configuration class: PvtForImageClassification (PVT model)
- PvtV2Config configuration class: PvtV2ForImageClassification (PVTv2 model)
- RegNetConfig configuration class: RegNetForImageClassification (RegNet model)
- ResNetConfig configuration class: ResNetForImageClassification (ResNet model)
- SegformerConfig configuration class: SegformerForImageClassification (SegFormer model)
- SiglipConfig configuration class: SiglipForImageClassification (SigLIP model)
- SwiftFormerConfig configuration class: SwiftFormerForImageClassification (SwiftFormer model)
- SwinConfig configuration class: SwinForImageClassification (Swin Transformer model)
- Swinv2Config configuration class: Swinv2ForImageClassification (Swin Transformer V2 model)
- VanConfig configuration class: VanForImageClassification (VAN model)
- ViTConfig configuration class: ViTForImageClassification (ViT model)
- ViTHybridConfig configuration class: ViTHybridForImageClassification (ViT Hybrid model)
- ViTMSNConfig configuration class: ViTMSNForImageClassification (ViTMSN model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
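For reference, a minimal sketch of this path (the ViT checkpoint below is only an illustrative assumption; any config whose class appears in the list above works):
>>> from transformers import AutoConfig, AutoModelForImageClassification
>>> # Only the config is fetched; the model weights are randomly initialized
>>> config = AutoConfig.from_pretrained("google/vit-base-patch16-224")
>>> model = AutoModelForImageClassification.from_config(config)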
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- beit — BeitForImageClassification (BEiT model)
- bit — BitForImageClassification (BiT model)
- clip — CLIPForImageClassification (CLIP model)
- convnext — ConvNextForImageClassification (ConvNeXT model)
- convnextv2 — ConvNextV2ForImageClassification (ConvNeXTV2 model)
- cvt — CvtForImageClassification (CvT model)
- data2vec-vision — Data2VecVisionForImageClassification (Data2VecVision model)
- deit — DeiTForImageClassification or DeiTForImageClassificationWithTeacher (DeiT model)
- dinat — DinatForImageClassification (DiNAT model)
- dinov2 — Dinov2ForImageClassification (DINOv2 model)
- efficientformer — EfficientFormerForImageClassification or EfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- efficientnet — EfficientNetForImageClassification (EfficientNet model)
- focalnet — FocalNetForImageClassification (FocalNet model)
- hiera — HieraForImageClassification (Hiera model)
- ijepa — IJepaForImageClassification (I-JEPA model)
- imagegpt — ImageGPTForImageClassification (ImageGPT model)
- levit — LevitForImageClassification or LevitForImageClassificationWithTeacher (LeViT model)
- mobilenet_v1 — MobileNetV1ForImageClassification (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2ForImageClassification (MobileNetV2 model)
- mobilevit — MobileViTForImageClassification (MobileViT model)
- mobilevitv2 — MobileViTV2ForImageClassification (MobileViTV2 model)
- nat — NatForImageClassification (NAT model)
- perceiver — PerceiverForImageClassificationLearned or PerceiverForImageClassificationFourier or PerceiverForImageClassificationConvProcessing (Perceiver model)
- poolformer — PoolFormerForImageClassification (PoolFormer model)
- pvt — PvtForImageClassification (PVT model)
- pvt_v2 — PvtV2ForImageClassification (PVTv2 model)
- regnet — RegNetForImageClassification (RegNet model)
- resnet — ResNetForImageClassification (ResNet model)
- segformer — SegformerForImageClassification (SegFormer model)
- siglip — SiglipForImageClassification (SigLIP model)
- swiftformer — SwiftFormerForImageClassification (SwiftFormer model)
- swin — SwinForImageClassification (Swin Transformer model)
- swinv2 — Swinv2ForImageClassification (Swin Transformer V2 model)
- van — VanForImageClassification (VAN model)
- vit — ViTForImageClassification (ViT model)
- vit_hybrid — ViTHybridForImageClassification (ViT Hybrid model)
- vit_msn — ViTMSNForImageClassification (ViTMSN model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/vit_tf_model_config.json")
>>> model = AutoModelForImageClassification.from_pretrained(
...     "./tf_model/vit_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
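To go from a loaded model to an actual prediction, something like the following works; it is a minimal sketch, assuming the `google/vit-base-patch16-224` checkpoint used above:
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, AutoModelForImageClassification
>>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits  # shape (batch_size, num_labels)
>>> model.config.id2label[logits.argmax(-1).item()]  # human-readable label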
TFAutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- ConvNextConfig configuration class: TFConvNextForImageClassification (ConvNeXT model)
- ConvNextV2Config configuration class: TFConvNextV2ForImageClassification (ConvNeXTV2 model)
- CvtConfig configuration class: TFCvtForImageClassification (CvT model)
- Data2VecVisionConfig configuration class: TFData2VecVisionForImageClassification (Data2VecVision model)
- DeiTConfig configuration class: TFDeiTForImageClassification or TFDeiTForImageClassificationWithTeacher (DeiT model)
- EfficientFormerConfig configuration class: TFEfficientFormerForImageClassification or TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- MobileViTConfig configuration class: TFMobileViTForImageClassification (MobileViT model)
- RegNetConfig configuration class: TFRegNetForImageClassification (RegNet model)
- ResNetConfig configuration class: TFResNetForImageClassification (ResNet model)
- SegformerConfig configuration class: TFSegformerForImageClassification (SegFormer model)
- SwiftFormerConfig configuration class: TFSwiftFormerForImageClassification (SwiftFormer model)
- SwinConfig configuration class: TFSwinForImageClassification (Swin Transformer model)
- ViTConfig configuration class: TFViTForImageClassification (ViT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model with the provided conversion scripts and then loading the TensorFlow model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- convnext — TFConvNextForImageClassification (ConvNeXT model)
- convnextv2 — TFConvNextV2ForImageClassification (ConvNeXTV2 model)
- cvt — TFCvtForImageClassification (CvT model)
- data2vec-vision — TFData2VecVisionForImageClassification (Data2VecVision model)
- deit — TFDeiTForImageClassification or TFDeiTForImageClassificationWithTeacher (DeiT model)
- efficientformer — TFEfficientFormerForImageClassification or TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- mobilevit — TFMobileViTForImageClassification (MobileViT model)
- regnet — TFRegNetForImageClassification (RegNet model)
- resnet — TFResNetForImageClassification (ResNet model)
- segformer — TFSegformerForImageClassification (SegFormer model)
- swiftformer — TFSwiftFormerForImageClassification (SwiftFormer model)
- swin — TFSwinForImageClassification (Swin Transformer model)
- vit — TFViTForImageClassification (ViT model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/vit_pt_model_config.json")
>>> model = TFAutoModelForImageClassification.from_pretrained(
...     "./pt_model/vit_pytorch_model.bin", from_pt=True, config=config
... )
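The TensorFlow classes follow the same pattern; a minimal inference sketch, assuming the same ViT checkpoint as above (with TF weights available on the Hub):
>>> import tensorflow as tf
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, TFAutoModelForImageClassification
>>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = processor(images=image, return_tensors="tf")
>>> logits = model(**inputs).logits
>>> model.config.id2label[int(tf.argmax(logits, axis=-1)[0])]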
FlaxAutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BeitConfig configuration class: FlaxBeitForImageClassification (BEiT model)
- Dinov2Config configuration class: FlaxDinov2ForImageClassification (DINOv2 model)
- RegNetConfig configuration class: FlaxRegNetForImageClassification (RegNet model)
- ResNetConfig configuration class: FlaxResNetForImageClassification (ResNet model)
- ViTConfig configuration class: FlaxViTForImageClassification (ViT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a Flax model with the provided conversion scripts and then loading the Flax model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- beit — FlaxBeitForImageClassification (BEiT model)
- dinov2 — FlaxDinov2ForImageClassification (DINOv2 model)
- regnet — FlaxRegNetForImageClassification (RegNet model)
- resnet — FlaxResNetForImageClassification (ResNet model)
- vit — FlaxViTForImageClassification (ViT model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/vit_pt_model_config.json")
>>> model = FlaxAutoModelForImageClassification.from_pretrained(
...     "./pt_model/vit_pytorch_model.bin", from_pt=True, config=config
... )
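Flax usage is analogous; a minimal sketch, assuming the same ViT repo also ships Flax weights:
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, FlaxAutoModelForImageClassification
>>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = processor(images=image, return_tensors="np")  # Flax models consume NumPy arrays
>>> logits = model(**inputs).logits
>>> model.config.id2label[int(logits.argmax(-1)[0])]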
AutoModelForVideoClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a video classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- TimesformerConfig configuration class: TimesformerForVideoClassification (TimeSformer model)
- VideoMAEConfig configuration class: VideoMAEForVideoClassification (VideoMAE model)
- VivitConfig configuration class: VivitForVideoClassification (ViViT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a video classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with a video classification head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- timesformer — TimesformerForVideoClassification (TimeSformer model)
- videomae — VideoMAEForVideoClassification (VideoMAE model)
- vivit — VivitForVideoClassification (ViViT model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForVideoClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> # Update configuration during loading
>>> model = AutoModelForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/videomae_tf_model_config.json")
>>> model = AutoModelForVideoClassification.from_pretrained(
...     "./tf_model/videomae_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
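A video model consumes a clip of frames rather than a single image. The sketch below uses random frames purely to show the expected input format; the `MCG-NJU/videomae-base-finetuned-kinetics` checkpoint and its 16-frame clip length are assumptions tied to that model:
>>> import numpy as np
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForVideoClassification
>>> processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> model = AutoModelForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> video = list(np.random.randint(0, 256, (16, 360, 640, 3), dtype=np.uint8))  # 16 dummy RGB frames
>>> inputs = processor(video, return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> model.config.id2label[logits.argmax(-1).item()]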
AutoModelForKeypointDetection
AutoModelForMaskedImageModeling
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked image modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- DeiTConfig configuration class: DeiTForMaskedImageModeling (DeiT model)
- FocalNetConfig configuration class: FocalNetForMaskedImageModeling (FocalNet model)
- SwinConfig configuration class: SwinForMaskedImageModeling (Swin Transformer model)
- Swinv2Config configuration class: Swinv2ForMaskedImageModeling (Swin Transformer V2 model)
- ViTConfig configuration class: ViTForMaskedImageModeling (ViT model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a masked image modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with a masked image modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- deit — DeiTForMaskedImageModeling (DeiT model)
- focalnet — FocalNetForMaskedImageModeling (FocalNet model)
- swin — SwinForMaskedImageModeling (Swin Transformer model)
- swinv2 — Swinv2ForMaskedImageModeling (Swin Transformer V2 model)
- vit — ViTForMaskedImageModeling (ViT model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/swin_tf_model_config.json")
>>> model = AutoModelForMaskedImageModeling.from_pretrained(
...     "./tf_model/swin_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
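Masked image modeling additionally takes a boolean mask over image patches. A minimal sketch, assuming the SimMIM checkpoint `microsoft/swin-base-simmim-window6-192` used above and a PIL `image` as in the earlier examples; the random mask is purely illustrative:
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForMaskedImageModeling
>>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
>>> model = AutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
>>> inputs = processor(images=image, return_tensors="pt")
>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> bool_masked_pos = torch.randint(0, 2, (1, num_patches)).bool()  # which patches to mask out
>>> outputs = model(**inputs, bool_masked_pos=bool_masked_pos)
>>> outputs.loss  # reconstruction loss on the masked patches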
TFAutoModelForMaskedImageModeling
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked image modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- DeiTConfig configuration class: TFDeiTForMaskedImageModeling (DeiT model)
- SwinConfig configuration class: TFSwinForMaskedImageModeling (Swin Transformer model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a masked image modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model with the provided conversion scripts and then loading the TensorFlow model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, optional, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with a masked image modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- deit — TFDeiTForMaskedImageModeling (DeiT model)
- swin — TFSwinForMaskedImageModeling (Swin Transformer model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/swin_pt_model_config.json")
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained(
...     "./pt_model/swin_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForObjectDetection
This is a generic model class that will be instantiated as one of the model classes of the library (with an object detection head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- ConditionalDetrConfig configuration class: ConditionalDetrForObjectDetection (Conditional DETR model)
- DeformableDetrConfig configuration class: DeformableDetrForObjectDetection (Deformable DETR model)
- DetaConfig configuration class: DetaForObjectDetection (DETA model)
- DetrConfig configuration class: DetrForObjectDetection (DETR model)
- RTDetrConfig configuration class: RTDetrForObjectDetection (RT-DETR model)
- TableTransformerConfig configuration class: TableTransformerForObjectDetection (Table Transformer model)
- YolosConfig configuration class: YolosForObjectDetection (YOLOS model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an object detection head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with an object detection head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- conditional_detr — ConditionalDetrForObjectDetection (Conditional DETR model)
- deformable_detr — DeformableDetrForObjectDetection (Deformable DETR model)
- deta — DetaForObjectDetection (DETA model)
- detr — DetrForObjectDetection (DETR model)
- rt_detr — RTDetrForObjectDetection (RT-DETR model)
- table-transformer — TableTransformerForObjectDetection (Table Transformer model)
- yolos — YolosForObjectDetection (YOLOS model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
>>> # Update configuration during loading
>>> model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/detr_tf_model_config.json")
>>> model = AutoModelForObjectDetection.from_pretrained(
...     "./tf_model/detr_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
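Detection outputs need post-processing to become boxes and labels. A minimal sketch, assuming the `facebook/detr-resnet-50` checkpoint used above and its image processor's `post_process_object_detection` helper:
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
>>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
>>> model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> target_sizes = torch.tensor([image.size[::-1]])  # (height, width) of the original image
>>> results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
...     print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())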
AutoModelForImageSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with an image segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- DetrConfig configuration class: DetrForSegmentation (DETR model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an image segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with an image segmentation head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- detr — DetrForSegmentation (DETR model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForImageSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
>>> # Update configuration during loading
>>> model = AutoModelForImageSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/detr_tf_model_config.json")
>>> model = AutoModelForImageSegmentation.from_pretrained(
...     "./tf_model/detr_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
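DetrForSegmentation produces panoptic-style outputs that also need post-processing; a minimal sketch, assuming the `facebook/detr-resnet-50-panoptic` checkpoint used above, a PIL `image` as in the earlier examples, and the processor's `post_process_panoptic_segmentation` helper:
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForImageSegmentation
>>> processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
>>> model = AutoModelForImageSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> result = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
>>> result["segmentation"].shape  # a segment id per pixel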
AutoModelForImageToImage
AutoModelForSemanticSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a semantic segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- BeitConfig configuration class: BeitForSemanticSegmentation (BEiT model)
- DPTConfig configuration class: DPTForSemanticSegmentation (DPT model)
- Data2VecVisionConfig configuration class: Data2VecVisionForSemanticSegmentation (Data2VecVision model)
- MobileNetV2Config configuration class: MobileNetV2ForSemanticSegmentation (MobileNetV2 model)
- MobileViTConfig configuration class: MobileViTForSemanticSegmentation (MobileViT model)
- MobileViTV2Config configuration class: MobileViTV2ForSemanticSegmentation (MobileViTV2 model)
- SegformerConfig configuration class: SegformerForSemanticSegmentation (SegFormer model)
- UperNetConfig configuration class: UperNetForSemanticSegmentation (UPerNet model)
- attn_implementation (`str`, optional) — The attention implementation to use in the model (if relevant). Can be `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 when available; otherwise the default is the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a semantic segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a TensorFlow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, optional, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, optional, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, optional, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, optional, defaults to `False`) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, optional, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, optional, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, optional, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be passed directly to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiates one of the model classes of the library (with a semantic segmentation head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it is missing, by falling back to pattern matching on `pretrained_model_name_or_path`:
- beit — BeitForSemanticSegmentation (BEiT model)
- data2vec-vision — Data2VecVisionForSemanticSegmentation (Data2VecVision model)
- dpt — DPTForSemanticSegmentation (DPT model)
- mobilenet_v2 — MobileNetV2ForSemanticSegmentation (MobileNetV2 model)
- mobilevit — MobileViTForSemanticSegmentation (MobileViT model)
- mobilevitv2 — MobileViTV2ForSemanticSegmentation (MobileViTV2 model)
- segformer — SegformerForSemanticSegmentation (SegFormer model)
- upernet — UperNetForSemanticSegmentation (UPerNet model)
The model is set in evaluation mode by default using `model.eval()` (e.g., dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSemanticSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
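The example above uses a generic BERT checkpoint purely to illustrate the loading API; for a meaningful prediction the model is paired with its image processor. A minimal end-to-end sketch follows; the checkpoint nvidia/segformer-b0-finetuned-ade-512-512 and the local file scene.jpg are assumptions for illustration:
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
>>> checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"  # assumed checkpoint
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
>>> image = Image.open("scene.jpg")  # assumed local file
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> # Upsample the low-resolution logits to the image size, then take the per-pixel argmax.
>>> upsampled = torch.nn.functional.interpolate(
...     outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
... )
>>> segmentation_map = upsampled.argmax(dim=1)[0]  # (height, width) map of class ids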
TFAutoModelForSemanticSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a semantic segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecVisionConfig configuration class: TFData2VecVisionForSemanticSegmentation (Data2VecVision model)
  - MobileViTConfig configuration class: TFMobileViTForSemanticSegmentation (MobileViT model)
  - SegformerConfig configuration class: TFSegformerForSemanticSegmentation (SegFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a semantic segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a semantic segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-vision — TFData2VecVisionForSemanticSegmentation (Data2VecVision model)
- mobilevit — TFMobileViTForSemanticSegmentation (MobileViT model)
- segformer — TFSegformerForSemanticSegmentation (SegFormer model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForInstanceSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with an instance segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - MaskFormerConfig configuration class: MaskFormerForInstanceSegmentation (MaskFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an instance segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
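For illustration, a minimal from_config() sketch; it builds a randomly initialized MaskFormer from a default MaskFormerConfig and downloads nothing:
>>> from transformers import AutoModelForInstanceSegmentation, MaskFormerConfig
>>> config = MaskFormerConfig()  # default hyperparameters, random weights
>>> model = AutoModelForInstanceSegmentation.from_config(config)
>>> type(model).__name__
'MaskFormerForInstanceSegmentation'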
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with an instance segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForInstanceSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForInstanceSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForInstanceSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForInstanceSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
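For an end-to-end prediction, the image processor that ships with the checkpoint also bundles the matching post-processing step. A sketch, assuming the checkpoint facebook/maskformer-swin-base-coco and a local image cats.jpg (both assumptions for illustration):
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForInstanceSegmentation
>>> checkpoint = "facebook/maskformer-swin-base-coco"  # assumed checkpoint
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForInstanceSegmentation.from_pretrained(checkpoint)
>>> image = Image.open("cats.jpg")  # assumed local file
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> result = image_processor.post_process_instance_segmentation(
...     outputs, target_sizes=[image.size[::-1]]
... )[0]
>>> # result["segmentation"] is a (height, width) map of segment ids;
>>> # result["segments_info"] lists label and score per segment.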
AutoModelForUniversalSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a universal image segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DetrConfig configuration class: DetrForSegmentation (DETR model)
  - Mask2FormerConfig configuration class: Mask2FormerForUniversalSegmentation (Mask2Former model)
  - MaskFormerConfig configuration class: MaskFormerForInstanceSegmentation (MaskFormer model)
  - OneFormerConfig configuration class: OneFormerForUniversalSegmentation (OneFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a universal image segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a universal image segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- detr — DetrForSegmentation (DETR model)
- mask2former — Mask2FormerForUniversalSegmentation (Mask2Former model)
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer model)
- oneformer — OneFormerForUniversalSegmentation (OneFormer model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForUniversalSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForUniversalSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForUniversalSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForUniversalSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
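A hedged panoptic-segmentation sketch with a universal-segmentation model; the checkpoint facebook/mask2former-swin-base-coco-panoptic and the local file street.jpg are assumptions for illustration:
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoImageProcessor, AutoModelForUniversalSegmentation
>>> checkpoint = "facebook/mask2former-swin-base-coco-panoptic"  # assumed checkpoint
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForUniversalSegmentation.from_pretrained(checkpoint)
>>> image = Image.open("street.jpg")  # assumed local file
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> panoptic = image_processor.post_process_panoptic_segmentation(
...     outputs, target_sizes=[image.size[::-1]]
... )[0]
>>> # panoptic["segmentation"] is a per-pixel segment-id map;
>>> # panoptic["segments_info"] lists the detected segments.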
AutoModelForZeroShotImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlignConfig configuration class: AlignModel (ALIGN model)
  - AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
  - Blip2Config configuration class: Blip2ForImageTextRetrieval (BLIP-2 model)
  - BlipConfig configuration class: BlipModel (BLIP model)
  - CLIPConfig configuration class: CLIPModel (CLIP model)
  - CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
  - ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
  - SiglipConfig configuration class: SiglipModel (SigLIP model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a zero-shot image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- align — AlignModel (ALIGN model)
- altclip — AltCLIPModel (AltCLIP model)
- blip — BlipModel (BLIP model)
- blip-2 — Blip2ForImageTextRetrieval (BLIP-2 model)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP model)
- clip — CLIPModel (CLIP model)
- clipseg — CLIPSegModel (CLIPSeg model)
- siglip — SiglipModel (SigLIP model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotImageClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
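A minimal zero-shot classification sketch with CLIP; the checkpoint openai/clip-vit-base-patch32 is real, while the file pet.jpg and the candidate labels are assumptions for illustration:
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
>>> checkpoint = "openai/clip-vit-base-patch32"
>>> processor = AutoProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
>>> image = Image.open("pet.jpg")  # assumed local file
>>> labels = ["a photo of a cat", "a photo of a dog"]
>>> inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> probs = outputs.logits_per_image.softmax(dim=-1)[0]  # one score per candidate label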
TFAutoModelForZeroShotImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BlipConfig configuration class: TFBlipModel (BLIP model)
  - CLIPConfig configuration class: TFCLIPModel (CLIP model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a zero-shot image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- blip — TFBlipModel (BLIP model)
- clip — TFCLIPModel (CLIP model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
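A hedged TensorFlow variant of the same zero-shot setup. It assumes the openai/clip-vit-base-patch32 repository also ships TF weights; if it does not, pass from_pt=True together with a config, as in the example above. The file pet.jpg is also an assumption:
>>> from PIL import Image
>>> import tensorflow as tf
>>> from transformers import AutoProcessor, TFAutoModelForZeroShotImageClassification
>>> checkpoint = "openai/clip-vit-base-patch32"  # assumed to include TF weights
>>> processor = AutoProcessor.from_pretrained(checkpoint)
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
>>> image = Image.open("pet.jpg")  # assumed local file
>>> labels = ["a photo of a cat", "a photo of a dog"]
>>> inputs = processor(text=labels, images=image, return_tensors="tf", padding=True)
>>> outputs = model(**inputs)
>>> probs = tf.nn.softmax(outputs.logits_per_image, axis=-1)[0]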
AutoModelForZeroShotObjectDetection
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot object detection head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - GroundingDinoConfig configuration class: GroundingDinoForObjectDetection (Grounding DINO model)
  - OmDetTurboConfig configuration class: OmDetTurboForObjectDetection (OmDet-Turbo model)
  - OwlViTConfig configuration class: OwlViTForObjectDetection (OWL-ViT model)
  - Owlv2Config configuration class: Owlv2ForObjectDetection (OWLv2 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot object detection head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a zero-shot object detection head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- grounding-dino — GroundingDinoForObjectDetection (Grounding DINO model)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo model)
- owlv2 — Owlv2ForObjectDetection (OWLv2 model)
- owlvit — OwlViTForObjectDetection (OWL-ViT model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForZeroShotObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
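A hedged text-conditioned detection sketch with OWL-ViT; the checkpoint google/owlvit-base-patch32 is real, while kitchen.jpg, the text queries, and the 0.1 threshold are assumptions for illustration:
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
>>> checkpoint = "google/owlvit-base-patch32"
>>> processor = AutoProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
>>> image = Image.open("kitchen.jpg")  # assumed local file
>>> inputs = processor(text=["a cup", "a fork"], images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> results = processor.post_process_object_detection(
...     outputs, threshold=0.1, target_sizes=torch.tensor([image.size[::-1]])
... )[0]
>>> # results["boxes"], results["scores"], results["labels"] at the original resolution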
Audio
The following auto classes are available for the following audio tasks.
AutoModelForAudioClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - ASTConfig configuration class: ASTForAudioClassification (Audio Spectrogram Transformer model)
  - Data2VecAudioConfig configuration class: Data2VecAudioForSequenceClassification (Data2VecAudio model)
  - HubertConfig configuration class: HubertForSequenceClassification (Hubert model)
  - SEWConfig configuration class: SEWForSequenceClassification (SEW model)
  - SEWDConfig configuration class: SEWDForSequenceClassification (SEW-D model)
  - UniSpeechConfig configuration class: UniSpeechForSequenceClassification (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForSequenceClassification (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForSequenceClassification (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForSequenceClassification (WavLM model)
  - WhisperConfig configuration class: WhisperForAudioClassification (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with an audio classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- audio-spectrogram-transformer — ASTForAudioClassification (Audio Spectrogram Transformer model)
- data2vec-audio — Data2VecAudioForSequenceClassification (Data2VecAudio model)
- hubert — HubertForSequenceClassification (Hubert model)
- sew — SEWForSequenceClassification (SEW model)
- sew-d — SEWDForSequenceClassification (SEW-D model)
- unispeech — UniSpeechForSequenceClassification (UniSpeech model)
- unispeech-sat — UniSpeechSatForSequenceClassification (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForSequenceClassification (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer model)
- wavlm — WavLMForSequenceClassification (WavLM model)
- whisper — WhisperForAudioClassification (Whisper model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
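A hedged classification sketch on raw audio; the keyword-spotting checkpoint superb/wav2vec2-base-superb-ks and the silent stand-in waveform are assumptions for illustration, and real audio should be resampled to the checkpoint's sampling rate (16 kHz here):
>>> import numpy as np
>>> import torch
>>> from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
>>> checkpoint = "superb/wav2vec2-base-superb-ks"  # assumed checkpoint
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
>>> model = AutoModelForAudioClassification.from_pretrained(checkpoint)
>>> waveform = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
>>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_label = model.config.id2label[int(logits.argmax(dim=-1))]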
TFAutoModelForAudioClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Wav2Vec2Config configuration class: TFWav2Vec2ForSequenceClassification (Wav2Vec2 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with an audio classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- wav2vec2 — TFWav2Vec2ForSequenceClassification (Wav2Vec2 model)
Example:
>>> from transformers import AutoConfig, TFAutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForAudioClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForAudioFrameClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio frame (token) classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecAudioConfig configuration class: Data2VecAudioForAudioFrameClassification (Data2VecAudio model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForAudioFrameClassification (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForAudioFrameClassification (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForAudioFrameClassification (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio frame (token) classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with an audio frame (token) classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForAudioFrameClassification (Data2VecAudio model)
- unispeech-sat — UniSpeechSatForAudioFrameClassification (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForAudioFrameClassification (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer model)
- wavlm — WavLMForAudioFrameClassification (WavLM model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForAudioFrameClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioFrameClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForAudioFrameClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioFrameClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
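As above, the BERT checkpoint only demonstrates the loading API. A minimal sketch of frame-level outputs with a randomly initialized model follows; the default Wav2Vec2Config and the two-label setup are illustrative assumptions, and real use would load a fine-tuned checkpoint with from_pretrained():
>>> import torch
>>> from transformers import AutoModelForAudioFrameClassification, Wav2Vec2Config
>>> config = Wav2Vec2Config(num_labels=2)  # hypothetical two-class setup, e.g. two speakers
>>> model = AutoModelForAudioFrameClassification.from_config(config)
>>> model = model.eval()  # switch off dropout for deterministic shapes
>>> waveform = torch.zeros(1, 16000)  # (batch, samples), one second at 16 kHz
>>> with torch.no_grad():
...     logits = model(input_values=waveform).logits
>>> # logits has shape (batch, frames, num_labels): one prediction per audio frame.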
AutoModelForCTC
This is a generic model class that will be instantiated as one of the model classes of the library (with a connectionist temporal classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecAudioConfig configuration class: Data2VecAudioForCTC (Data2VecAudio model)
  - HubertConfig configuration class: HubertForCTC (Hubert model)
  - MCTCTConfig configuration class: MCTCTForCTC (M-CTC-T model)
  - SEWConfig configuration class: SEWForCTC (SEW model)
  - SEWDConfig configuration class: SEWDForCTC (SEW-D model)
  - UniSpeechConfig configuration class: UniSpeechForCTC (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForCTC (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForCTC (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForCTC (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForCTC (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a connectionist temporal classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a tensorflow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a connectionist temporal classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForCTC (Data2VecAudio model)
- hubert — HubertForCTC (Hubert model)
- mctct — MCTCTForCTC (M-CTC-T model)
- sew — SEWForCTC (SEW model)
- sew-d — SEWDForCTC (SEW-D model)
- unispeech — UniSpeechForCTC (UniSpeech model)
- unispeech-sat — UniSpeechSatForCTC (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForCTC (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForCTC (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer model)
- wavlm — WavLMForCTC (WavLM model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Example:
>>> from transformers import AutoConfig, AutoModelForCTC
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> # Update configuration during loading
>>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/wav2vec2_tf_model_config.json")
>>> model = AutoModelForCTC.from_pretrained(
... "./tf_model/wav2vec2_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- Pop2PianoConfig configuration class: Pop2PianoForConditionalGeneration (Pop2Piano model)
- SeamlessM4TConfig configuration class: SeamlessM4TForSpeechToText (SeamlessM4T model)
- SeamlessM4Tv2Config configuration class: SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 model)
- Speech2TextConfig configuration class: Speech2TextForConditionalGeneration (Speech2Text model)
- SpeechEncoderDecoderConfig configuration class: SpeechEncoderDecoderModel (Speech Encoder decoder model)
- SpeechT5Config configuration class: SpeechT5ForSpeechToText (SpeechT5 model)
- WhisperConfig configuration class: WhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
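For illustration, a minimal from_config sketch in the same style as the other auto classes in this document; the openai/whisper-base checkpoint is an assumption here, and any configuration from the list above works the same way:
>>> from transformers import AutoConfig, AutoModelForSpeechSeq2Seq
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("openai/whisper-base")
>>> model = AutoModelForSpeechSeq2Seq.from_config(config)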
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- pop2piano — Pop2PianoForConditionalGeneration (Pop2Piano model)
- seamless_m4t — SeamlessM4TForSpeechToText (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 model)
- speech-encoder-decoder — SpeechEncoderDecoderModel (Speech Encoder decoder model)
- speech_to_text — Speech2TextForConditionalGeneration (Speech2Text model)
- speecht5 — SpeechT5ForSpeechToText (SpeechT5 model)
- whisper — WhisperForConditionalGeneration (Whisper model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base")
>>> # Update configuration during loading
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/whisper_tf_model_config.json")
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained(
... "./tf_model/whisper_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Speech2TextConfig configuration class: TFSpeech2TextForConditionalGeneration (Speech2Text model)
  - WhisperConfig configuration class: TFWhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
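A minimal sketch, assuming the openai/whisper-base checkpoint (any configuration listed above works the same way):
>>> from transformers import AutoConfig, TFAutoModelForSpeechSeq2Seq
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("openai/whisper-base")
>>> model = TFAutoModelForSpeechSeq2Seq.from_config(config)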
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and then loading the TensorFlow model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- speech_to_text — TFSpeech2TextForConditionalGeneration (Speech2Text model)
- whisper — TFWhisperForConditionalGeneration (Whisper model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/whisper_pt_model_config.json")
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/whisper_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - SpeechEncoderDecoderConfig configuration class: FlaxSpeechEncoderDecoderModel (Speech Encoder decoder model)
  - WhisperConfig configuration class: FlaxWhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
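A minimal sketch, assuming the openai/whisper-base checkpoint (any configuration listed above works the same way):
>>> from transformers import AutoConfig, FlaxAutoModelForSpeechSeq2Seq
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("openai/whisper-base")
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_config(config)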
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model using the provided conversion scripts and then loading the Flax model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- speech-encoder-decoder — FlaxSpeechEncoderDecoderModel (Speech Encoder decoder model)
- whisper — FlaxWhisperForConditionalGeneration (Whisper model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/whisper_pt_model_config.json")
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/whisper_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForAudioXVector
This is a generic model class that will be instantiated as one of the model classes of the library (with audio retrieval via an x-vector head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- Data2VecAudioConfig configuration class: Data2VecAudioForXVector (Data2VecAudio model)
- UniSpeechSatConfig configuration class: UniSpeechSatForXVector (UniSpeechSat model)
- Wav2Vec2BertConfig configuration class: Wav2Vec2BertForXVector (Wav2Vec2-BERT model)
- Wav2Vec2Config configuration class: Wav2Vec2ForXVector (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer model)
- WavLMConfig configuration class: WavLMForXVector (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with audio retrieval via an x-vector head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
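A minimal sketch, assuming the microsoft/wavlm-base-plus-sv checkpoint (a WavLM x-vector configuration; any configuration listed above works the same way):
>>> from transformers import AutoConfig, AutoModelForAudioXVector
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("microsoft/wavlm-base-plus-sv")
>>> model = AutoModelForAudioXVector.from_config(config)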
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with audio retrieval via an x-vector head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForXVector (Data2VecAudio model)
- unispeech-sat — UniSpeechSatForXVector (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForXVector (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForXVector (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer model)
- wavlm — WavLMForXVector (WavLM model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForAudioXVector
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv")
>>> # Update configuration during loading
>>> model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/wav2vec2_tf_model_config.json")
>>> model = AutoModelForAudioXVector.from_pretrained(
... "./tf_model/wav2vec2_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForTextToSpectrogram
AutoModelForTextToWaveform
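Both classes follow the same from_pretrained() pattern as the other auto classes; a minimal sketch, assuming the microsoft/speecht5_tts and facebook/musicgen-small checkpoints:
>>> from transformers import AutoModelForTextToSpectrogram, AutoModelForTextToWaveform
>>> # Assumed checkpoints; each class resolves the concrete architecture from the configuration.
>>> spectrogram_model = AutoModelForTextToSpectrogram.from_pretrained("microsoft/speecht5_tts")
>>> waveform_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")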
Multimodal
The following auto classes are available for the following multimodal tasks.
AutoModelForTableQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a table question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - TapasConfig configuration class: TapasForQuestionAnswering (TAPAS model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a table question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
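A minimal sketch, assuming the google/tapas-base-finetuned-wtq checkpoint used in the example further below:
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google/tapas-base-finetuned-wtq")
>>> model = AutoModelForTableQuestionAnswering.from_config(config)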
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a table question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- tapas — TapasForQuestionAnswering (TAPAS model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/tapas_tf_model_config.json")
>>> model = AutoModelForTableQuestionAnswering.from_pretrained(
... "./tf_model/tapas_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForTableQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a table question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - TapasConfig configuration class: TFTapasForQuestionAnswering (TAPAS model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a table question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
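A minimal sketch, assuming the google/tapas-base-finetuned-wtq checkpoint used in the example further below:
>>> from transformers import AutoConfig, TFAutoModelForTableQuestionAnswering
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("google/tapas-base-finetuned-wtq")
>>> model = TFAutoModelForTableQuestionAnswering.from_config(config)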
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and then loading the TensorFlow model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a table question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- tapas — TFTapasForQuestionAnswering (TAPAS model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/tapas_pt_model_config.json")
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained(
... "./pt_model/tapas_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForDocumentQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a document question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- LayoutLMConfig configuration class: LayoutLMForQuestionAnswering (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a document question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
Examples:
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = AutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a document question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- layoutlm — LayoutLMForQuestionAnswering (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/layoutlm_tf_model_config.json")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(
... "./tf_model/layoutlm_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForDocumentQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a document question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - LayoutLMConfig configuration class: TFLayoutLMForQuestionAnswering (LayoutLM model)
  - LayoutLMv3Config configuration class: TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a document question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
Examples:
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and then loading the TensorFlow model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a document question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- layoutlm — TFLayoutLMForQuestionAnswering (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/layoutlm_pt_model_config.json")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained(
... "./pt_model/layoutlm_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForVisualQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a visual question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- Blip2Config configuration class: Blip2ForConditionalGeneration (BLIP-2 model)
- BlipConfig configuration class: BlipForQuestionAnswering (BLIP model)
- ViltConfig configuration class: ViltForQuestionAnswering (ViLT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a visual question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
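A minimal sketch, assuming the dandelin/vilt-b32-finetuned-vqa checkpoint used in the example further below:
>>> from transformers import AutoConfig, AutoModelForVisualQuestionAnswering
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> model = AutoModelForVisualQuestionAnswering.from_config(config)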
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a visual question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- blip — BlipForQuestionAnswering (BLIP model)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 model)
- vilt — ViltForQuestionAnswering (ViLT model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForVisualQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> # Update configuration during loading
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/vilt_tf_model_config.json")
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained(
... "./tf_model/vilt_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- Blip2Config configuration class: Blip2ForConditionalGeneration (BLIP-2 model)
- BlipConfig configuration class: BlipForConditionalGeneration (BLIP model)
- ChameleonConfig configuration class: ChameleonForConditionalGeneration (Chameleon model)
- GitConfig configuration class: GitForCausalLM (GIT model)
- Idefics2Config configuration class: Idefics2ForConditionalGeneration (Idefics2 model)
- Idefics3Config configuration class: Idefics3ForConditionalGeneration (Idefics3 model)
- InstructBlipConfig configuration class: InstructBlipForConditionalGeneration (InstructBLIP model)
- InstructBlipVideoConfig configuration class: InstructBlipVideoForConditionalGeneration (InstructBlipVideo model)
- Kosmos2Config configuration class: Kosmos2ForConditionalGeneration (KOSMOS-2 model)
- LlavaConfig configuration class: LlavaForConditionalGeneration (LLaVa model)
- LlavaNextConfig configuration class: LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- LlavaNextVideoConfig configuration class: LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
- LlavaOnevisionConfig configuration class: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- MllamaConfig configuration class: MllamaForConditionalGeneration (Mllama model)
- PaliGemmaConfig configuration class: PaliGemmaForConditionalGeneration (PaliGemma model)
- Pix2StructConfig configuration class: Pix2StructForConditionalGeneration (Pix2Struct model)
- Qwen2VLConfig configuration class: Qwen2VLForConditionalGeneration (Qwen2VL model)
- VideoLlavaConfig configuration class: VideoLlavaForConditionalGeneration (VideoLlava model)
- VipLlavaConfig configuration class: VipLlavaForConditionalGeneration (VipLlava model)
- VisionEncoderDecoderConfig configuration class: VisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using F.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
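A minimal sketch, assuming the Salesforce/blip-image-captioning-base checkpoint (a BLIP configuration; any configuration listed above works the same way):
>>> from transformers import AutoConfig, AutoModelForVision2Seq
>>> # Download configuration from huggingface.co and cache; no weights are loaded.
>>> config = AutoConfig.from_pretrained("Salesforce/blip-image-captioning-base")
>>> model = AutoModelForVision2Seq.from_config(config)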
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and then loading the PyTorch model.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and instantiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- blip — BlipForConditionalGeneration (BLIP model)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 model)
- chameleon — ChameleonForConditionalGeneration (Chameleon model)
- git — GitForCausalLM (GIT model)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 model)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 model)
- instructblip — InstructBlipForConditionalGeneration (InstructBLIP model)
- instructblipvideo — InstructBlipVideoForConditionalGeneration (InstructBlipVideo model)
- kosmos-2 — Kosmos2ForConditionalGeneration (KOSMOS-2 model)
- llava — LlavaForConditionalGeneration (LLaVa model)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- mllama — MllamaForConditionalGeneration (Mllama model)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma model)
- pix2struct — Pix2StructForConditionalGeneration (Pix2Struct model)
- qwen2_vl — Qwen2VLForConditionalGeneration (Qwen2VL model)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava model)
- vipllava — VipLlavaForConditionalGeneration (VipLlava model)
- vision-encoder-decoder — VisionEncoderDecoderModel (Vision Encoder decoder model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache them.
>>> # The checkpoint below is illustrative; it must belong to an architecture
>>> # supported by this auto class (here, BLIP).
>>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Update configuration during loading
>>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/blip_tf_model_config.json")
>>> model = AutoModelForVision2Seq.from_pretrained(
...     "./tf_model/blip_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
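In practice, a vision-to-text model is paired with its processor to go from an image to generated text. The following is a minimal captioning sketch; the `Salesforce/blip-image-captioning-base` checkpoint and the image URL are illustrative assumptions, not part of the original reference:
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModelForVision2Seq
>>> # Illustrative checkpoint; any architecture supported by AutoModelForVision2Seq works.
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Placeholder image URL; any RGB image will do.
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")  # pixel values for the vision encoder
>>> generated_ids = model.generate(**inputs, max_new_tokens=20)
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])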
TFAutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BlipConfig configuration class: TFBlipForConditionalGeneration (BLIP model)
  - VisionEncoderDecoderConfig configuration class: TFVisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
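To make that caveat concrete, here is a minimal sketch (the BLIP checkpoint name is an illustrative assumption) showing that `from_config` builds an untrained model:
>>> from transformers import AutoConfig, TFAutoModelForVision2Seq
>>> # Only the configuration is fetched here; no pretrained weights are downloaded or loaded.
>>> config = AutoConfig.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # The resulting model is randomly initialized from the config.
>>> model = TFAutoModelForVision2Seq.from_config(config)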
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding any cached versions that exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether or not to only look at local files (i.e., do not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- blip — TFBlipForConditionalGeneration (BLIP model)
- vision-encoder-decoder — TFVisionEncoderDecoderModel (Vision Encoder decoder model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache them.
>>> # The checkpoint below is illustrative; it must belong to an architecture
>>> # this auto class supports (here, BLIP) and provide TF weights.
>>> model = TFAutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/blip_pt_model_config.json")
>>> model = TFAutoModelForVision2Seq.from_pretrained(
...     "./pt_model/blip_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - VisionEncoderDecoderConfig configuration class: FlaxVisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
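As with the TF variant above, a minimal sketch (the vision-encoder-decoder checkpoint name is an illustrative assumption) of building an untrained Flax model from a configuration alone:
>>> from transformers import AutoConfig, FlaxAutoModelForVision2Seq
>>> # Only the configuration is fetched; the resulting FlaxVisionEncoderDecoderModel
>>> # is randomly initialized, not pretrained.
>>> config = AutoConfig.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> model = FlaxAutoModelForVision2Seq.from_config(config)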
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a PyTorch state_dict save file (e.g., `./pt_model/pytorch_model.bin`). In this case, `from_pt` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the PyTorch model and loading the converted model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding any cached versions that exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether or not to only look at local files (i.e., do not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- vision-encoder-decoder — FlaxVisionEncoderDecoderModel (Vision Encoder decoder model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache them.
>>> # The checkpoint below is an illustrative vision-encoder-decoder model;
>>> # from_pt=True converts its PyTorch weights if no Flax weights are available.
>>> model = FlaxAutoModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning", from_pt=True)
>>> # Update configuration during loading
>>> model = FlaxAutoModelForVision2Seq.from_pretrained(
...     "nlpconnect/vit-gpt2-image-captioning", from_pt=True, output_attentions=True
... )
>>> model.config.output_attentions
True
>>> # Loading from a local PyTorch checkpoint file instead of Flax weights (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/model_config.json")
>>> model = FlaxAutoModelForVision2Seq.from_pretrained(
...     "./pt_model/pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForImageTextToText
This is a generic model class that will be instantiated as one of the model classes of the library (with an image-text-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- Blip2Config configuration class: Blip2ForConditionalGeneration (BLIP-2 model)
- BlipConfig configuration class: BlipForConditionalGeneration (BLIP model)
- ChameleonConfig configuration class: ChameleonForConditionalGeneration (Chameleon model)
- FuyuConfig configuration class: FuyuForCausalLM (Fuyu model)
- GitConfig configuration class: GitForCausalLM (GIT model)
- Idefics2Config configuration class: Idefics2ForConditionalGeneration (Idefics2 model)
- Idefics3Config configuration class: Idefics3ForConditionalGeneration (Idefics3 model)
- IdeficsConfig configuration class: IdeficsForVisionText2Text (IDEFICS model)
- InstructBlipConfig configuration class: InstructBlipForConditionalGeneration (InstructBLIP model)
- Kosmos2Config configuration class: Kosmos2ForConditionalGeneration (KOSMOS-2 model)
- LlavaConfig configuration class: LlavaForConditionalGeneration (LLaVa model)
- LlavaNextConfig configuration class: LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- LlavaOnevisionConfig configuration class: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- MllamaConfig configuration class: MllamaForConditionalGeneration (Mllama model)
- PaliGemmaConfig configuration class: PaliGemmaForConditionalGeneration (PaliGemma model)
- Pix2StructConfig configuration class: Pix2StructForConditionalGeneration (Pix2Struct model)
- PixtralVisionConfig configuration class: LlavaForConditionalGeneration (Pixtral model)
- Qwen2VLConfig configuration class: Qwen2VLForConditionalGeneration (Qwen2VL model)
- UdopConfig configuration class: UdopForConditionalGeneration (UDOP model)
- VipLlavaConfig configuration class: VipLlavaForConditionalGeneration (VipLlava model)
- VisionEncoderDecoderConfig configuration class: VisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (`str`, *optional*) — The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using `F.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Instantiates one of the model classes of the library (with an image-text-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
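A minimal sketch of the same point for this class, assuming the illustrative `llava-hf/llava-1.5-7b-hf` checkpoint:
>>> from transformers import AutoConfig, AutoModelForImageTextToText
>>> # Only the configuration is downloaded; the model below has random weights.
>>> config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> model = AutoModelForImageTextToText.from_config(config)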
from_pretrained
< source >( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (`str` or `os.PathLike`) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., `./my_model_directory/`.
  - A path or url to a tensorflow index checkpoint file (e.g., `./tf_model/model.ckpt.index`). In this case, `from_tf` should be set to `True` and a configuration object should be provided as the `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, *optional*) — Will be passed along to the underlying model's `__init__()` method.
- config (PretrainedConfig, *optional*) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named config.json is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, *optional*) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (`str` or `os.PathLike`, *optional*) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (`bool`, *optional*, defaults to `False`) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`) — Whether or not to force the (re-)download of the model weights and configuration files, overriding any cached versions that exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*) — A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`) — Whether or not to only look at local files (i.e., do not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`) — The specific revision to use for the code on the Hub, if the code is stored in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
- kwargs (additional keyword arguments, *optional*) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded:
  - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, `kwargs` will first be passed to the configuration class initialization function (from_pretrained()). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function.
Instantiate one of the model classes of the library (with an image-text-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the `model_type` property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or, when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
- blip — BlipForConditionalGeneration (BLIP model)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 model)
- chameleon — ChameleonForConditionalGeneration (Chameleon model)
- fuyu — FuyuForCausalLM (Fuyu model)
- git — GitForCausalLM (GIT model)
- idefics — IdeficsForVisionText2Text (IDEFICS model)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 model)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 model)
- instructblip — InstructBlipForConditionalGeneration (InstructBLIP model)
- kosmos-2 — Kosmos2ForConditionalGeneration (KOSMOS-2 model)
- llava — LlavaForConditionalGeneration (LLaVa model)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- mllama — MllamaForConditionalGeneration (Mllama model)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma model)
- pix2struct — Pix2StructForConditionalGeneration (Pix2Struct model)
- pixtral — LlavaForConditionalGeneration (Pixtral model)
- qwen2_vl — Qwen2VLForConditionalGeneration (Qwen2VL model)
- udop — UdopForConditionalGeneration (UDOP model)
- vipllava — VipLlavaForConditionalGeneration (VipLlava model)
- vision-encoder-decoder — VisionEncoderDecoderModel (Vision Encoder decoder model)
The model is set in evaluation mode by default using `model.eval()` (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`.
Examples:
>>> from transformers import AutoConfig, AutoModelForImageTextToText
>>> # Download model and configuration from huggingface.co and cache them.
>>> # The checkpoint below is illustrative; any architecture supported by
>>> # this auto class (here, BLIP) works.
>>> model = AutoModelForImageTextToText.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Update configuration during loading
>>> model = AutoModelForImageTextToText.from_pretrained("Salesforce/blip-image-captioning-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/blip_tf_model_config.json")
>>> model = AutoModelForImageTextToText.from_pretrained(
...     "./tf_model/blip_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
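For prompted image-to-text generation, this auto class is typically combined with AutoProcessor. A minimal sketch, assuming the illustrative `Salesforce/blip-image-captioning-base` checkpoint and a placeholder image URL:
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> # Checkpoint and image URL are illustrative assumptions.
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> model = AutoModelForImageTextToText.from_pretrained("Salesforce/blip-image-captioning-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # The text prompt is continued by the model, conditioned on the image.
>>> inputs = processor(images=image, text="a photography of", return_tensors="pt")
>>> generated_ids = model.generate(**inputs, max_new_tokens=20)
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])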