Keras 3 API 文档 / 层 API / 核心层 / InputSpec 对象

InputSpec 对象

[source]

InputSpec class

keras.InputSpec(
    dtype=None,
    shape=None,
    ndim=None,
    max_ndim=None,
    min_ndim=None,
    axes=None,
    allow_last_axis_squeeze=False,
    name=None,
    optional=False,
)

指定每一层输入的秩、数据类型和形状.

层可以公开(如果合适)一个 input_spec 属性:一个 InputSpec 实例,或一个嵌套结构的 InputSpec 实例(每个输入张量一个).这些对象使层能够为 Layer.__call__ 的第一个参数运行输入结构、输入秩、输入形状和输入数据类型的输入兼容性检查.

形状中的 None 条目与任何维度兼容.

参数: dtype: 预期的输入数据类型. shape: 形状元组,预期的输入形状(可能包括 None 表示动态轴). 包括批量大小. ndim: 整数,预期的输入秩. max_ndim: 整数,输入的最大秩. min_ndim: 整数,输入的最小秩. axes: 映射整数轴到特定维度值的字典. allow_last_axis_squeeze: 如果为 True,允许秩为 N+1 的输入,只要输入的最后一个轴为 1,以及秩为 N-1 的输入,只要规范的最后一个轴为 1. name: 当以字典形式传递数据时,与此输入对应的预期键. optional: 布尔值,输入是否可选.可选输入可以接受 None 值.

示例:

class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        # 该层将接受形状为 (*, 28, 28) 和 (*, 28, 28, 1) 的输入
        # 并在其他情况下引发适当的错误消息.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)