load_svmlight_files#

sklearn.datasets.load_svmlight_files(files, *, n_features=None, dtype=<class 'numpy.float64'>, multilabel=False, zero_based='auto', query_id=False, offset=0, length=-1)#

从多个文件加载数据集,文件格式为SVMlight格式。

此函数等效于对文件列表映射load_svmlight_file,不同之处在于结果被连接成一个扁平列表,并且样本向量被约束为具有相同数量的特征。

如果文件包含成对偏好约束(在svmlight格式中称为“qid”),除非query_id参数设置为True,否则这些约束将被忽略。这些成对偏好约束可以用于在使用成对损失函数时约束样本的组合(如在某些学习排序问题中),以便只考虑具有相同query_id值的对。

Parameters:
filesarray-like, dtype=str, path-like, file-like or int

(路径)要加载的文件。如果路径以“.gz”或“.bz2”结尾,将在运行时解压缩。如果传递了一个整数,则假定为文件描述符。文件类和文件描述符将不会被此函数关闭。文件类对象必须以二进制模式打开。

Changed in version 1.2: 现在接受路径类对象。

n_featuresint, default=None

要使用的特征数量。如果为None,将从任何文件中出现的最大列索引推断出来。

这可以设置为比输入文件中实际特征数量更高的值,但设置为较低的值将导致引发异常。

dtypenumpy数据类型, default=np.float64

要加载的数据集的数据类型。这将是要输出的numpy数组 Xy 的数据类型。

multilabelbool, default=False

样本可以有多个标签(参见 https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html)。

zero_basedbool or “auto”, default=”auto”

f中的列索引是否为零基(True)或一基(False)。如果列索引为一基,则将其转换为零基以匹配Python/NumPy约定。 如果设置为“auto”,将应用启发式检查以根据文件内容确定这一点。这两种文件在野外都会出现,但不幸的是它们不是自标识的。使用“auto”或True时,如果没有传递偏移量或长度,应该总是安全的。 如果传递了偏移量或长度,“auto”模式将回退到zero_based=True,以避免启发式检查在文件的不同段上产生不一致的结果。

query_idbool, default=False

如果为True,将返回每个文件的query_id数组。

offsetint, default=0

通过向前查找忽略第一个偏移字节,然后丢弃直到下一个换行字符的字节。

lengthint, default=-1

如果严格为正,一旦文件中的位置达到(偏移量 + 长度)字节阈值,将停止读取任何新行数据。

Returns:
[X1, y1, …, Xn, yn] or [X1, y1, q1, …, Xn, yn, qn]: 数组列表

每个(Xi, yi)对是load_svmlight_file(files[i])的结果。 如果query_id设置为True,这将返回(Xi, yi, qi)三元组。

See also

load_svmlight_file

用于加载单个文件的类似函数。

Notes

当将模型拟合到矩阵X_train并对矩阵X_test进行评估时,确保X_train和X_test具有相同数量的特征(X_train.shape[1] == X_test.shape[1])是至关重要的。如果您使用load_svmlight_file单独加载文件,情况可能并非如此。

Examples

使用joblib.Memory缓存svmlight文件:

from joblib import Memory
from sklearn.datasets import load_svmlight_file
mem = Memory("./mycache")

@mem.cache
def get_data():
    data_train, target_train, data_test, target_test = load_svmlight_files(
        ["svmlight_file_train", "svmlight_file_test"]
    )
    return data_train, target_train, data_test, target_test

X_train, y_train, X_test, y_test = get_data()