Keras 3 API 文档 / 内置小型数据集 / IMDB电影评论情感分类数据集

IMDB电影评论情感分类数据集

[source]

load_data function

keras.datasets.imdb.load_data(
    path="imdb.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)

加载IMDB数据集.

这是一个包含25,000条来自IMDB的电影评论数据集,按情感(正面/负面)标记.评论已经过预处理,每条评论被编码为单词索引(整数)列表.为了方便起见,单词按其在数据集中的总体频率进行索引,因此例如整数"3”表示数据中第3个最常见的单词.这允许快速过滤操作,例如:"仅考虑前10,000个最常见的单词,但排除前20个最常见的单词”.

按照惯例,"0”不代表特定单词,而是用于编码填充标记.

参数: path: 缓存数据的位置(相对于~/.keras/dataset). num_words: 整数或None.单词按其出现频率(在训练集中)排序,只保留num_words个最常见的单词.任何不常见的单词将在序列数据中显示为oov_char值.如果为None,则保留所有单词.默认为None. skip_top: 跳过前N个最常见的单词(这些单词可能不具有信息量).这些单词将在数据集中显示为oov_char值.当为0时,不跳过任何单词.默认为0. maxlen: 整数或None.最大序列长度.任何较长的序列将被截断.None表示不截断.默认为None. seed: 整数.用于可重复数据混洗的种子. start_char: 整数.序列的开始将用此字符标记.0通常是填充字符.默认为1. oov_char: 整数.词汇表外字符.由于num_wordsskip_top限制而被剪切的单词将替换为此字符. index_from: 整数.实际单词从此索引及以上开始索引.

返回: Numpy数组元组: (x_train, y_train), (x_test, y_test).

x_train, x_test: 序列列表,即索引(整数)列表.如果指定了num_words参数,则可能的最大索引值为num_words - 1.如果指定了maxlen参数,则可能的最大序列长度为maxlen.

y_train, y_test: 整数标签列表(1或0).

注意: "词汇表外”字符仅用于在训练集中出现但在当前未达到num_words阈值的单词.在训练集中未见但在测试集中出现的单词已被简单跳过.


[source]

get_word_index function

keras.datasets.imdb.get_word_index(path="imdb_word_index.json")

检索一个映射单词到其在IMDB数据集中索引的字典.

参数: path: 缓存数据的位置(相对于 ~/.keras/dataset).

返回: 单词索引字典.键是单词字符串,值是它们的索引.

示例:

# 使用默认参数调用 keras.datasets.imdb.load_data
start_char = 1
oov_char = 2
index_from = 3
# 检索训练序列.
(x_train, _), _ = keras.datasets.imdb.load_data(
    start_char=start_char, oov_char=oov_char, index_from=index_from
)
# 检索单词索引文件映射单词到索引
word_index = keras.datasets.imdb.get_word_index()
# 反转单词索引以获得映射索引到单词的字典
# 并向索引添加 `index_from` 以与 `x_train` 同步
inverted_word_index = dict(
    (i + index_from, word) for (word, i) in word_index.items()
)
# 更新 `inverted_word_index` 以包括 `start_char` 和 `oov_char`
inverted_word_index[start_char] = "[START]"
inverted_word_index[oov_char] = "[OOV]"
# 解码数据集中的第一个序列
decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0])