训练数据集基类

class darts.utils.data.training_dataset.DualCovariatesTrainingDataset[源代码]

基类:TrainingDataset, ABC

用于 DualCovariatesTorchModel 训练数据集的抽象类。它包含 (past_target, historic_future_covariates, future_covariates, static_covariates, future_target) 的 4-元组 np.ndarray。协变量是可选的,可以是 None

class darts.utils.data.training_dataset.FutureCovariatesTrainingDataset[源代码]

基类:TrainingDataset, ABC

用于 FutureCovariatesTorchModel 训练数据集的抽象类。它包含 (past_target, future_covariate, static_covariates, future_target) 的 3-元组 np.ndarray。协变量是可选的,可以是 None

class darts.utils.data.training_dataset.MixedCovariatesTrainingDataset[源代码]

基类:TrainingDataset, ABC

混合协变量Torch模型训练数据集的抽象类。它包含 (过去目标, 过去协变量, 历史未来协变量, 未来协变量, 静态协变量, 未来目标) 的 5-元组 np.ndarray。协变量是可选的,可以是 None

class darts.utils.data.training_dataset.PastCovariatesTrainingDataset[源代码]

基类:TrainingDataset, ABC

用于训练 PastCovariatesTorchModel 的抽象类数据集。它包含 (past_target, past_covariate, static_covariates, future_target) 的 3-元组 np.ndarray。协变量是可选的,可以是 None

class darts.utils.data.training_dataset.SplitCovariatesTrainingDataset[源代码]

基类:TrainingDataset, ABC

用于 SplitCovariatesTorchModel 训练数据集的抽象类。它包含 (past_target, past_covariates, future_covariates, static_covariates, future_target) 的 4 元组 np.ndarray。协变量是可选的,可以是 None

class darts.utils.data.training_dataset.TrainingDataset[源代码]

基类:ABC, Dataset

Darts 中所有 torch 模型训练数据集的超类。这些包括

  • “PastCovariates”数据集(用于PastCovariatesTorchModel):包含(past_target,

    past_covariates, static_covariates, future_target)

  • “FutureCovariates” 数据集(用于 FutureCovariatesTorchModel):包含(past_target,

    future_covariates, static_covariates, future_target)

  • “DualCovariates” 数据集(用于 DualCovariatesTorchModel):包含(past_target,

    historic_future_covariates, future_covariates, static_covariates, future_target)

  • “MixedCovariates” 数据集(用于 MixedCovariatesTorchModel):包含(past_target,

    past_covariates, historic_future_covariates, future_covariates, static_covariates, future_target)

  • “SplitCovariates” 数据集(用于 SplitCovariatesTorchModel):包含(past_target,

    past_covariates, future_covariates, static_covariates, future_target)

协变量是可选的,可以是 None

这是用于训练(或验证)的,除了 future_target 之外的所有数据都表示模型输入(future_target 是模型训练来预测的输出)。

Darts 的 TorchForecastingModel 可以通过 fit_from_dataset() 方法从正确类型的 TrainingDataset 实例中进行拟合。

TrainingDataset 继承了 torch 的 Dataset;这意味着实现必须提供 __getitem__() 方法。

它包含 np.ndarray`(而不是 `TimeSeries),因为训练只需要数值,因此当我们通过返回 TimeSeries 底层数据的 numpy 视图来进行切片时,可以获得较大的性能提升。