Keras 3 API 文档 / 内置小型数据集 / MNIST数字分类数据集

MNIST数字分类数据集

[source]

load_data function

keras.datasets.mnist.load_data(path="mnist.npz")

加载MNIST数据集.

这是一个包含60,000张28x28灰度图像的数据集,图像内容为10个数字, 以及一个包含10,000张图像的测试集. 更多信息可以在MNIST主页找到.

参数: path: 本地缓存数据集的路径 (相对于~/.keras/datasets).

返回: 由NumPy数组组成的元组: (x_train, y_train), (x_test, y_test).

x_train: uint8 NumPy数组,包含灰度图像数据,形状为 (60000, 28, 28),包含训练数据.像素值范围 从0到255.

y_train: uint8 NumPy数组,包含数字标签(范围0-9的整数) 形状为(60000,),用于训练数据.

x_test: uint8 NumPy数组,包含灰度图像数据,形状为 (10000, 28, 28),包含测试数据.像素值范围 从0到255.

y_test: uint8 NumPy数组,包含数字标签(范围0-9的整数) 形状为(10000,),用于测试数据.

示例:

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

许可证:

Yann LeCun和Corinna Cortes持有MNIST数据集的版权, 该数据集是原始NIST数据集的衍生作品. MNIST数据集根据知识共享署名-相同方式共享3.0许可提供.