导出 🤗 Transformers 模型到 ONNX
🤗 Transformers 提供了一个 transformers.onnx
包,使您能够通过利用配置对象将模型检查点转换为 ONNX 图。
查看有关导出🤗 Transformers模型的指南以获取更多详细信息。
ONNX 配置
我们提供了三个抽象类,您应该根据您希望导出的模型架构类型来继承它们:
- 基于编码器的模型继承自 OnnxConfig
- 基于解码器的模型继承自 OnnxConfigWithPast
- 编码器-解码器模型继承自 OnnxSeq2SeqConfigWithPast
OnnxConfig
类 transformers.onnx.OnnxConfig
< source >( config: PretrainedConfig task: str = '默认' patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = 无 )
用于描述如何通过ONNX格式导出模型的元数据的ONNX可导出模型的基类。
flatten_output_collection_property
< source >( name: str field: typing.Iterable[typing.Any] ) → (Dict[str, Any])
展平任何潜在的嵌套结构,使用结构中元素的索引扩展字段名称。
from_model_config
< source >( config: PretrainedConfig task: str = 'default' )
为特定模型实例化一个OnnxConfig
generate_dummy_inputs
< source >( preprocessor: typing.Union[ForwardRef('PreTrainedTokenizerBase'), ForwardRef('FeatureExtractionMixin'), ForwardRef('ImageProcessingMixin')] batch_size: int = -1 seq_length: int = -1 num_choices: int = -1 is_pair: bool = False framework: typing.Optional[transformers.utils.generic.TensorType] = None num_channels: int = 3 image_width: int = 40 image_height: int = 40 sampling_rate: int = 22050 time_duration: float = 5.0 frequency: int = 220 tokenizer: PreTrainedTokenizerBase = None )
参数
- 预处理器 — (PreTrainedTokenizerBase, FeatureExtractionMixin, or ImageProcessingMixin): 与此模型配置相关的预处理器。
- batch_size (
int
, optional, defaults to -1) — 导出模型的批量大小(-1 表示动态轴)。 - num_choices (
int
, optional, defaults to -1) — 为多项选择任务提供的候选答案数量(-1 表示动态轴)。 - seq_length (
int
, optional, defaults to -1) — 导出模型的序列长度(-1 表示动态轴)。 - is_pair (
bool
, optional, defaults toFalse
) — 指示输入是否为一对(句子1,句子2) - 框架 (
TensorType
, 可选, 默认为None
) — 分词器将为其生成张量的框架(PyTorch 或 TensorFlow)。 - num_channels (
int
, optional, defaults to 3) — 生成图像的通道数。 - image_width (
int
, optional, 默认为 40) — 生成图像的宽度。 - image_height (
int
, optional, 默认为 40) — 生成图像的高度。 - sampling_rate (
int
, optional 默认为 22050) — 用于音频数据生成的采样率。 - time_duration (
float
, optional 默认为 5.0) — 音频数据生成的总采样秒数。 - 频率 (
int
, 可选 默认为 220) — 生成音频的期望自然频率。
生成输入以提供给特定框架的ONNX导出器
generate_dummy_inputs_onnxruntime
< source >( reference_model_inputs: typing.Mapping[str, typing.Any] ) → Mapping[str, Tensor]
使用参考模型输入生成ONNX运行时的输入。覆盖此方法以运行具有编码器和解码器分别导出为单独ONNX文件的seq2seq模型的推理。
标志指示模型是否需要使用外部数据格式
OnnxConfigWithPast
类 transformers.onnx.OnnxConfigWithPast
< source >( config: PretrainedConfig task: str = '默认' patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = 无 use_past: bool = 假 )
fill_with_past_key_values_
< source >( inputs_or_outputs: typing.Mapping[str, typing.Mapping[int, str]] 方向: str inverted_values_shape: bool = False )
考虑到过去的键值动态轴,填充input_or_outputs映射。
with_past
< source >( config: PretrainedConfig task: str = 'default' )
实例化一个带有use_past
属性设置为True的OnnxConfig
OnnxSeq2SeqConfigWithPast
类 transformers.onnx.OnnxSeq2SeqConfigWithPast
< source >( config: PretrainedConfig task: str = '默认' patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = 无 use_past: bool = 假 )
ONNX 特性
每个ONNX配置都与一组功能相关联,这些功能使您能够为不同类型的拓扑或任务导出模型。
功能管理器
check_supported_model_or_raise
< source >( model: typing.Union[ForwardRef('PreTrainedModel'), ForwardRef('TFPreTrainedModel')] feature: str = 'default' )
检查模型是否具有所请求的功能。
determine_framework
< source >( model: str framework: str = None )
确定用于导出的框架。
优先级按以下顺序排列:
- 用户通过
framework
输入。 - 如果提供了本地检查点,则使用与检查点相同的框架。
- 环境中可用的框架,优先使用PyTorch
获取配置
< source >( model_type: str feature: str ) → OnnxConfig
获取模型类型和特征组合的OnnxConfig。
get_model_class_for_feature
< source >( feature: str framework: str = 'pt' )
尝试从特征名称中检索AutoModel类。
get_model_from_feature
< source >( feature: str model: str framework: str = None cache_dir: str = None )
尝试从模型名称和要启用的功能中检索模型。
获取支持的模型类型功能
< source >( model_type: str model_name: typing.Optional[str] = None )
尝试从模型类型中检索特征 -> OnnxConfig 构造函数映射。