# Source code for sktime.transformations.panel.dwt

"""Discrete wavelet transform."""

import math

import numpy as np
import pandas as pd

from sktime.datatypes import convert
from sktime.transformations.base import BaseTransformer

__author__ = ["vnicholson1"]


class DWTTransformer(BaseTransformer):
    """Discrete Wavelet Transform Transformer.

    Performs the Haar wavelet transformation on a time series. For each
    series the output is the approximation coefficients of the highest level
    followed by the wavelet coefficients from level ``num_levels`` down to 1.

    Parameters
    ----------
    num_levels : int, default=3
        Number of levels of the Haar wavelet transformation to perform.
        Must be non-negative; ``0`` returns each series unchanged.

    Examples
    --------
    >>> from sktime.transformations.panel.dwt import DWTTransformer
    >>> from sktime.datasets import load_airline
    >>> from sktime.datatypes import convert
    >>>
    >>> y = load_airline()
    >>> y = convert(y, to="Panel")
    >>> transformer = DWTTransformer(num_levels=3)
    >>> y_transformed = transformer.fit_transform(y)
    """

    _tags = {
        "authors": "vnicholson1",
        # what is the scitype of X: Series, or Panel
        "scitype:transform-input": "Series",
        # what scitype is returned: Primitives, Series, Panel
        "scitype:transform-output": "Series",
        # is this an instance-wise transform?
        "scitype:instancewise": False,
        # which mtypes do _fit/_transform support for X?
        "X_inner_mtype": "nested_univ",
        # which mtypes do _fit/_transform support for y?
        "y_inner_mtype": "None",
        "fit_is_empty": True,
    }

    def __init__(self, num_levels=3):
        self.num_levels = num_levels
        super().__init__()

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        private _transform containing core logic, called from transform

        Parameters
        ----------
        X : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of X must contain pandas.Series
            Data to fit transform to
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of Xt contains pandas.Series
            transformed version of X; rows keep the row index of X

        Raises
        ------
        TypeError, ValueError
            If ``num_levels`` is invalid (see ``_check_parameters``).
        """
        self._check_parameters()

        # Build on X's index so each output row stays aligned with its input
        # instance (an empty DataFrame would silently fall back to 0..n-1).
        Xt = pd.DataFrame(index=X.index)
        for col in X.columns:
            # Flatten this column to a 2D numpy array, one row per instance.
            arr = convert(
                pd.DataFrame(X[col]),
                from_type="nested_univ",
                to_type="numpyflat",
                as_scitype="Panel",
            )
            coeffs = np.asarray(self._extract_wavelet_coefficients(arr))
            # Re-nest: every cell of the output column holds one pd.Series.
            Xt[col] = [pd.Series(row) for row in coeffs]
        return Xt

    def _extract_wavelet_coefficients(self, data):
        """Extract Haar wavelet coefficients from each series in ``data``.

        For every series the coefficients are laid out as
        ``[a_L, d_L, d_(L-1), ..., d_1]`` with ``L = num_levels``, where
        ``a_L`` are the approximation coefficients of the highest level and
        ``d_k`` are the wavelet coefficients of level ``k``. When
        ``num_levels == 0`` the series is returned unchanged.

        Parameters
        ----------
        data : iterable of 1D array-like
            The time series to decompose, one per row.

        Returns
        -------
        list
            One coefficient sequence per series in ``data``.
        """
        num_levels = self.num_levels
        res = []
        for x in data:
            if num_levels == 0:
                res.append(x)
                continue
            details = []
            current = x
            for _ in range(num_levels):
                details.append(self._get_wavelet_coefficients(current))
                current = self._get_approx_coefficients(current)
            # Highest-level approximation first, then wavelet coefficients
            # from the highest level down to level 1.
            coeffs = list(current)
            for level_details in reversed(details):
                coeffs.extend(level_details)
            res.append(coeffs)
        return res

    def _check_parameters(self):
        """Check the values of parameters passed to DWT.

        Raises
        ------
        TypeError
            If ``num_levels`` is not an integer (``int`` or numpy integer).
        ValueError
            If ``num_levels`` is negative.
        """
        if not isinstance(self.num_levels, (int, np.integer)):
            raise TypeError(
                "num_levels must be an 'int'. Found '"
                + type(self.num_levels).__name__
                + "' instead."
            )
        if self.num_levels < 0:
            raise ValueError("num_levels must have the value of at least 0")

    def _get_approx_coefficients(self, arr):
        """Get the approximation coefficients at a given level.

        Each adjacent pair ``(arr[2i], arr[2i + 1])`` maps to
        ``(arr[2i] + arr[2i + 1]) / sqrt(2)``; a trailing unpaired element of
        an odd-length input is dropped. A length-1 input is returned as a
        one-element list.
        """
        if len(arr) == 1:
            return [arr[0]]
        scale = math.sqrt(2)
        return [(arr[2 * i] + arr[2 * i + 1]) / scale for i in range(len(arr) // 2)]

    def _get_wavelet_coefficients(self, arr):
        """Get the wavelet coefficients at a given level.

        Each adjacent pair ``(arr[2i], arr[2i + 1])`` maps to
        ``(arr[2i] - arr[2i + 1]) / sqrt(2)``; a trailing unpaired element of
        an odd-length input is dropped. A length-1 input is returned as a
        one-element list (not a difference).
        """
        if len(arr) == 1:
            return [arr[0]]
        scale = math.sqrt(2)
        return [(arr[2 * i] - arr[2 * i + 1]) / scale for i in range(len(arr) // 2)]