import os
import random
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode
import requests
# %load_ext autoreload
# %autoreload 2
from IPython.display import Math
from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GatedAdditiveTreeEnsembleConfig,
    MDNConfig
)
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
    ExperimentConfig,
)
# from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
from pytorch_tabular.models.common.heads import LinearHeadConfig, MixtureDensityHeadConfig
np.random.seed(42)

Utility Functions

def generate_linear_example(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    y_data = np.add(5*x_data, np.multiply((x_data)**2, np.random.standard_normal(x_data.shape)))

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_non_linear_example(samples=int(1e5)):
    x_data = np.float32(np.random.uniform(-10, 10, (1, samples)))
    r_data = np.array([np.random.normal(scale=np.abs(i)) for i in x_data])
    y_data = np.float32(np.square(x_data)+r_data*2.0)

    x_data2 = np.float32(np.random.uniform(-10, 10, (1, samples)))
    r_data2 = np.array([np.random.normal(scale=np.abs(i)) for i in x_data2])
    y_data2 = np.float32(-np.square(x_data2)+r_data2*2.0)

    x_data = np.concatenate((x_data,x_data2),axis=1).T
    y_data = np.concatenate((y_data,y_data2),axis=1).T

    min_max_scaler = MinMaxScaler()
    y_data = min_max_scaler.fit_transform(y_data)

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42, shuffle=True)
    x_test = np.linspace(-10,10,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_step_linear_example(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    y_data = np.zeros(x_data.shape)
    mask = x_data<0.5
    y_data[mask] = np.add(5*x_data[mask], np.multiply((x_data[mask])**2, np.random.standard_normal(x_data[mask].shape)))
    y_data[~mask] = np.add(100*x_data[~mask]+x_data[~mask]**2 , np.multiply((x_data[~mask])**2, np.random.standard_normal(x_data[~mask].shape)))
    min_max_scaler = MinMaxScaler()
    y_data = min_max_scaler.fit_transform(y_data)

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42, shuffle=True)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_gaussian_mixture(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    pi = np.sin(x_data)+3*x_data*np.cos(x_data)
    pi = pi/pi.max()
    # g1 = np.random.sample(samples) * 4 * x_data.squeeze()
    # g2 = np.random.sample(samples) * 15 * x_data.squeeze()
    g1 = 2*x_data.squeeze() + 0.5*np.random.sample(samples)
    g2 = 8*x_data.squeeze() + 0.5*np.random.sample(samples)

    # rounding pi makes the component choice deterministic; a Bernoulli(pi) draw would randomize it
    y_data = pi.round().squeeze()*g1 + (1-pi.round().squeeze())*g2
    y_data = y_data.reshape(-1,1)
    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def latex_to_png(formula, file):
    # render a LaTeX formula to PNG via the codecogs web service
    r = requests.get(r'http://latex.codecogs.com/png.latex?\dpi{300} \huge %s' % formula)
    with open(file, 'wb') as f:
        f.write(r.content)
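
A quick usage sketch (the formula and output path here are illustrative; the call fetches the rendered PNG from the codecogs web API):

latex_to_png(r'\epsilon \backsim \mathcal{N}(0,1)', 'imgs/eq_epsilon.png')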

Linear Example

df_train, df_valid, df_test, target_col = generate_linear_example()

Plot

# display(Math(r"$y = 5x + (x^2 * \epsilon)$"+"\n"+r"$\epsilon \backsim \mathcal{N}(0,1)$"))
fig = px.scatter(df_train, x="col1", y="target", title=r"$y = 5x + (x^2 * \epsilon)$"+"\n"+r"$\epsilon \backsim \mathcal{N}(0,1)$")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_1.png")

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_1.png")

Training the MDN

Define the Configs

epochs = 15
batch_size = 128
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

# MixtureDensityHeadConfig is a config dataclass; .__dict__ turns it into the plain dict that MDNConfig takes as head_config
mdn_head_config = MixtureDensityHeadConfig(num_gaussian=1).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # 每一层的节点数量
    activation="ReLU",  # 各层之间的激活
    initialization="kaiming"
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/seed.py:48: LightningDeprecationWarning:

`pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use `lightning_lite.utilities.seed.seed_everything` instead.

Global seed set to 42
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:198: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

`head` is not a valid parameter for backbone task. Making `head=None`
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/torch/cuda/__init__.py:497: UserWarning:

Can't initialize NVML

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=98` reached.
LR finder stopped early after 98 steps due to diverging loss.
Learning rate set to 0.0009120108393559097
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fe815ad0-1801-45e6-9700-6b960f3736da.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fe815ad0-1801-45e6-9700-6b960f3736da.ckpt

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 194   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.7 K     Trainable params
0         Non-trainable params
8.7 K     Total params
0.035     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
`Trainer.fit` stopped: `max_epochs=15` reached.
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7f00c9e47f10>

Prediction and Visualization

# the requested quantiles are estimated from n_samples draws of the predicted mixture for each row
pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100)
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`


Generating Predictions...:   0%|          | 0/8 [00:00<?, ?it/s]
col1 target_prediction target_q25 target_q50 target_q75
0 0.000000 0.426740 0.149168 0.443037 0.678223
1 0.001001 0.403517 0.169063 0.382718 0.627589
2 0.002002 0.500595 0.189833 0.531668 0.795271
3 0.003003 0.433998 0.172423 0.431257 0.682992
4 0.004004 0.463966 0.220691 0.469996 0.699911
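
Since df_test here is a synthetic grid without targets, an error metric would be computed on df_valid instead. A minimal sketch using the mean_squared_error imported above (an assumption: it scores only the point prediction and ignores the quantiles):

valid_pred = tabular_model.predict(df_valid)
rmse = mean_squared_error(valid_pred['target'], valid_pred['target_prediction'], squared=False)
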
fig = go.Figure([
    go.Scatter(
        name='Mean',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Upper Bound',
        x=pred_df['col1'],
        y=pred_df['target_q75'],
        mode='lines',
        marker=dict(color='rgba(0,147,201,0.3)'),
        line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound',
        x=pred_df['col1'],
        y=pred_df['target_q25'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(0,147,201,0.3)',
        fill='tonexty',
        showlegend=False
    )
])
fig.update_layout(
    yaxis_title='y',
    title='Mixture Density Network Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_mdn_1.png")

Non-Linear Example

df_train, df_valid, df_test, target_col = generate_non_linear_example()

Plot

fig = px.scatter(df_train, x="col1", y="target", title=r"$y = \pm x^2 + \epsilon$"+"\n"+r"$\epsilon\backsim\mathcal{N}(0,|x|)$")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_2.png")

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_2.png")

Training a Feed-Forward Network

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)
optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})
model_config = CategoryEmbeddingModelConfig(
    task="regression",
    layers="16-8",  # 每层中的节点数量
    activation="ReLU",  # 各层之间的激活
    head="LinearHead",
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/seed.py:48: LightningDeprecationWarning:

`pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use `lightning_lite.utilities.seed.seed_everything` instead.

Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.012022644346174132
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_50589bdd-1fbd-4a3c-810e-22c399ce6d75.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_50589bdd-1fbd-4a3c-810e-22c399ce6d75.ckpt

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 168   
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | head             | LinearHead                | 9     
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
179       Trainable params
0         Non-trainable params
179       Total params
0.001     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:48: UserWarning:

Detected KeyboardInterrupt, attempting graceful shutdown...

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7f01afe8a590>

Prediction and Visualization

pred_df = tabular_model.predict(df_valid.sample(1000).sort_values("col1"))
pred_df.head()
Generating Predictions...:   0%|          | 0/1 [00:00<?, ?it/s]
col1 target target_prediction
52696 -9.971139 0.234386 0.505665
84854 -9.920759 0.198878 0.505669
58540 -9.911808 0.726360 0.505670
9330 -9.907802 0.805257 0.505671
54741 -9.893517 0.807486 0.505672
fig = go.Figure([
    go.Scatter(
        name='Prediction',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Actual',
        x=pred_df['col1'],
        y=pred_df['target'],
        mode='markers',
        line=dict(color='rgba(60,180,229,1)'),
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Category Embedding Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_non_mdn_2.png")

Training the MDN

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=False, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)
optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

# weight_regularization=2 applies an L2 penalty to the MDN layers; lambda_mu and lambda_pi set its strength for the mean and mixing-weight layers
mdn_head_config = MixtureDensityHeadConfig(num_gaussian=2, weight_regularization=2, lambda_mu=10, lambda_pi=5).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # 每层中的节点数量
    activation="ReLU",  # 各层之间的激活
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/seed.py:48: LightningDeprecationWarning:

`pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use `lightning_lite.utilities.seed.seed_everything` instead.

Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 388   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.9 K     Trainable params
0         Non-trainable params
8.9 K     Total params
0.036     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:48: UserWarning:

Detected KeyboardInterrupt, attempting graceful shutdown...

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7f7fc50ef8e0>

Prediction and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`


Generating Predictions...:   0%|          | 0/1 [00:00<?, ?it/s]
col1 target_prediction target_q25 target_q50 target_q75 pi_0 pi_1 sigma_0 sigma_1 mu_0 ... backbone_features_54 backbone_features_55 backbone_features_56 backbone_features_57 backbone_features_58 backbone_features_59 backbone_features_60 backbone_features_61 backbone_features_62 backbone_features_63
0 -10.000000 0.329563 0.227722 0.277360 0.316893 0.000037 0.000020 0.055194 0.054383 0.255731 ... 0.0 0.0 4.998535 1.226815 0.0 3.516335 0.0 1.714176 2.625350 4.461626
1 -9.979980 0.652835 0.648949 0.704733 0.751973 0.000037 0.000020 0.055109 0.054314 0.256374 ... 0.0 0.0 4.999834 1.226926 0.0 3.506327 0.0 1.722667 2.624132 4.448001
2 -9.959960 0.671895 0.659679 0.706704 0.753247 0.000037 0.000019 0.055025 0.054245 0.257018 ... 0.0 0.0 5.001133 1.227038 0.0 3.496319 0.0 1.731157 2.622915 4.434377
3 -9.939939 0.488492 0.255757 0.619482 0.704288 0.000037 0.000019 0.054941 0.054176 0.257662 ... 0.0 0.0 5.002433 1.227150 0.0 3.486310 0.0 1.739647 2.621696 4.420753
4 -9.919920 0.696155 0.677998 0.725608 0.762295 0.000037 0.000019 0.054858 0.054107 0.258305 ... 0.0 0.0 5.003733 1.227261 0.0 3.476303 0.0 1.748137 2.620478 4.407130

5 rows × 75 columns
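
Because ret_logits=True returns the raw mixture parameters, the full predictive density at any input can be reconstructed. A minimal sketch (assuming the pi_* columns are pre-softmax logits, as the softmax cell further below suggests, and that each component is Gaussian):

from scipy.special import softmax
from scipy.stats import norm

row = pred_df.iloc[0]  # mixture parameters at the first test point
weights = softmax([row['pi_0'], row['pi_1']])  # logits -> mixing probabilities
grid = np.linspace(0, 1, 200)  # the target was MinMax-scaled to [0, 1]
density = sum(
    w * norm.pdf(grid, loc=row[f'mu_{k}'], scale=row[f'sigma_{k}'])
    for k, w in enumerate(weights)
)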

df = df_valid.sample(10000)
fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        line=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(36, 37, 130, 1)'),
    ),
    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),
    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        marker=dict(color='rgba(47, 47, 162, 0.5)'),
# line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        marker=dict(color='rgba(250, 152, 174, 0.5)'),
# line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Mixture Density Network Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_mdn_2.png")

Gaussian Mixture Model

df_train, df_valid, df_test, target_col = generate_gaussian_mixture()

Plot

from IPython.display import display, Math, Latex
eqn = r'$\pi = \frac{\sin(x) + 3x\cos(x)}{\max \left (\sin(x) + 3x\cos(x) \right )} \\ \\ g1 = 2x + 0.5 \epsilon \rightarrow  \epsilon \backsim \mathcal{N}(0,1) \\ g2 = 8x + 0.5 \epsilon \rightarrow  \epsilon \backsim \mathcal{N}(0,1) \\ p = Bernoulli(\pi) \rightarrow \text{Samples one of two outcomes based on the value of } \pi \\ y = p \times g1 + (1-p) \times g2$'
display(Math(eqn))
fig = px.scatter(df_train, x="col1", y="target")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_3.png")
$\displaystyle \pi = \frac{\sin(x) + 3x\cos(x)}{\max \left (\sin(x) + 3x\cos(x) \right )} \\ \\ g1 = 2x + 0.5 \epsilon \rightarrow \epsilon \backsim \mathcal{N}(0,1) \\ g2 = 8x + 0.5 \epsilon \rightarrow \epsilon \backsim \mathcal{N}(0,1) \\ p = Bernoulli(\pi) \rightarrow \text{Samples one of two outcomes based on the value of } \pi \\ y = p \times g1 + (1-p) \times g2$

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_3.png")

Training a Feed-Forward Network

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

model_config = CategoryEmbeddingModelConfig(
    task="regression",
    layers="16-8",  # 每层节点数
    activation="ReLU",  # 各层之间的激活
    head="LinearHead",
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:198: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.01
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_37f44922-c147-4423-999a-61166be4c5a9.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_37f44922-c147-4423-999a-61166be4c5a9.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 168   
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | head             | LinearHead                | 9     
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
179       Trainable params
0         Non-trainable params
179       Total params
0.001     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
`Trainer.fit` stopped: `max_epochs=200` reached.
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fd8160a98d0>

Prediction and Visualization

pred_df = tabular_model.predict(df_valid.sample(1000).sort_values("col1"))
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`


Generating Predictions...:   0%|          | 0/1 [00:00<?, ?it/s]
col1 target target_prediction
46976 0.001606 0.310145 0.587298
13609 0.001873 0.502399 0.589231
672 0.004710 0.052382 0.609763
37622 0.005097 0.538130 0.612562
8230 0.007761 0.390704 0.631844
fig = go.Figure([
    go.Scatter(
        name='Prediction',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Actual',
        x=pred_df['col1'],
        y=pred_df['target'],
        mode='markers',
        line=dict(color='rgba(60,180,229,1)'),
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Category Embedding Network Prediction',
    hovermode="x"
)
fig.write_image("imgs/prob_reg_non_mdn_3.png")

Training the MDN

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping_patience=5,
    early_stopping="valid_loss",
    checkpoints="valid_loss",
    load_best=True  # reload the best checkpoint (tracked via valid_loss) once training finishes
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(num_gaussian=2, weight_regularization=2).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # 每一层的节点数量
    activation="ReLU",  # 各层之间的激活
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:198: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`

Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.0007585775750291836
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_653e468e-c0c6-4fa7-922a-59c28c09d328.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_653e468e-c0c6-4fa7-922a-59c28c09d328.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 388   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.9 K     Trainable params
0         Non-trainable params
8.9 K     Total params
0.036     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:48: UserWarning:

Detected KeyboardInterrupt, attempting graceful shutdown...

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fd6ad1e9f90>

Prediction and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_datamodule.py:202: FutureWarning:

In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`


Generating Predictions...:   0%|          | 0/1 [00:00<?, ?it/s]
col1 target_prediction target_q25 target_q50 target_q75 pi_0 pi_1 sigma_0 sigma_1 mu_0 ... backbone_features_54 backbone_features_55 backbone_features_56 backbone_features_57 backbone_features_58 backbone_features_59 backbone_features_60 backbone_features_61 backbone_features_62 backbone_features_63
0 0.000000 0.643981 0.512064 0.660999 0.767942 15.547802 0.401295 0.199566 0.020728 0.651822 ... 1.046951 0.0 1.101677 1.591893 0.0 1.616410 0.0 0.0 2.758485 2.189235
1 0.001001 0.695453 0.584647 0.673015 0.819583 15.507115 0.406832 0.199811 0.020848 0.659085 ... 1.045177 0.0 1.099991 1.589368 0.0 1.613189 0.0 0.0 2.754719 2.183303
2 0.002002 0.681246 0.554822 0.703955 0.802527 15.466434 0.412371 0.200056 0.020968 0.666350 ... 1.043402 0.0 1.098305 1.586841 0.0 1.609968 0.0 0.0 2.750953 2.177371
3 0.003003 0.667499 0.537470 0.662630 0.798900 15.425751 0.417906 0.200301 0.021090 0.673611 ... 1.041627 0.0 1.096619 1.584315 0.0 1.606747 0.0 0.0 2.747187 2.171439
4 0.004004 0.658052 0.521760 0.633820 0.838818 15.385064 0.423444 0.200547 0.021212 0.680872 ... 1.039852 0.0 1.094933 1.581789 0.0 1.603526 0.0 0.0 2.743420 2.165506

5 rows × 75 columns

df = df_valid.sample(10000)
fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        line=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(90, 92, 237, 1)'),
    ),
    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),
    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        marker=dict(color='rgba(47, 47, 162, 0.5)'),
# line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        marker=dict(color='rgba(250, 152, 174, 0.5)'),
# line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
# yaxis_range=[0, 1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mdn_3.png")

fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        line=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(90, 92, 237, 1)'),
    ),
    go.Scatter(
        name='Mixing Coefficient 1',
        x=pred_df['col1'],
        y=pred_df['pi_0'],  # pi_0 pairs with Component 1; raw logits until the softmax cell below
        mode='lines',
        line=dict(color='rgba(255, 216, 117, 1)'),
    ),

    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        marker=dict(color='rgba(47, 47, 162, 0.5)'),
# line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),

])
fig.update_layout(
    yaxis_title='y',
# yaxis_range=[-0.2,1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mixing1_3.png")

fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        marker=dict(color='rgba(153, 115, 142, 0.2)'),
    ),

    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),

    go.Scatter(
        name='Mixing Coefficient 2',
        x=pred_df['col1'],
        y=pred_df['pi_1'],
        mode='lines',
        line=dict(color='rgba(255, 216, 117, 1)'),
    ),

    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        marker=dict(color='rgba(250, 152, 174, 0.5)'),
# line=dict(width=0)
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
# yaxis_range=[-0.2, 1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mixing2_3.png")

from scipy.special import softmax
pred_df[['pi_0','pi_1']] = softmax(pred_df[['pi_0','pi_1']].values, axis=-1)
fig = px.line(pred_df, x='col1', y=['pi_0','pi_1'])
fig.write_image("imgs/prob_reg_mixing12_3.png")

California Housing Dataset

from sklearn.datasets import fetch_california_housing
target_col = "target"
data = fetch_california_housing(return_X_y=False)
X = pd.DataFrame(data['data'], columns=data['feature_names'])
cont_cols = X.columns.tolist()
cat_cols = []
y = data['target']
X[target_col] = y
df_train, df_test = train_test_split(X, test_size=0.2, random_state=42)
df_train, df_valid = train_test_split(df_train, test_size=0.2, random_state=42)

Plotting

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_4.png")

Training the MDN

Define the Configs

Let's use a handy utility function from the package to find the probable centers of the Gaussian components. Internally, it runs a K-Means algorithm and returns the cluster centers, which we can set as the bias initialization (the sketch below shows the intuition).

from pytorch_tabular.utils import get_gaussian_centers

mu_init = get_gaussian_centers(df_train[target_col], n_components=4)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:870: FutureWarning:

The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning

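For intuition, the utility is roughly equivalent to clustering the target values with K-Means and reading off the cluster centers. A minimal sketch under that assumption (the library's internals may differ):

from sklearn.cluster import KMeans

def gaussian_centers_sketch(y, n_components=4):
    # Cluster the 1-D target and use the centers as candidate mu initializers
    km = KMeans(n_clusters=n_components, n_init=10, random_state=42)
    km.fit(np.asarray(y).reshape(-1, 1))
    return sorted(km.cluster_centers_.ravel().tolist())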

epochs = 1000
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=cont_cols,
    categorical_cols=cat_cols,
# continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss",
    load_best=True
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(
    num_gaussian=4, 
    weight_regularization=2,
    mu_bias_init=mu_init
).__dict__

#lambda_pi=10 
#lambda_sigma=1 
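For reference, the mixture density head models the conditional distribution as a weighted sum of num_gaussian Gaussians, p(y|x) = sum_k pi_k(x) N(y; mu_k(x), sigma_k(x)), and is trained by minimizing the negative log-likelihood. A minimal NumPy sketch of that loss (illustrative names, not the library's implementation):

import numpy as np
from scipy.special import softmax
from scipy.stats import norm

def mdn_nll_sketch(pi_logits, mu, sigma, y):
    # pi_logits, mu, sigma: arrays of shape (n, k); y: shape (n, 1)
    pi = softmax(pi_logits, axis=-1)                 # mixing probabilities
    comp = norm.pdf(y, loc=mu, scale=sigma)          # per-component densities, (n, k)
    return -np.log((pi * comp).sum(axis=-1)).mean()  # average mixture NLL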

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="200-100",  # 每层的节点数量
    activation="ReLU",  # 各层之间的激活
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)


tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.0009120108393559097
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c8d4ff90-1aad-4dcb-986b-2548e08389cc.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c8d4ff90-1aad-4dcb-986b-2548e08389cc.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 21.9 K
1 | _embedding_layer | Embedding1dLayer          | 16    
2 | _head            | MixtureDensityHead        | 1.2 K 
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.092     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
(validation progress output repeated each epoch; truncated)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:48: UserWarning:

Detected KeyboardInterrupt, attempting graceful shutdown...

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning:

`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fd6a27173d0>

Prediction and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
Generating Predictions...:   0%|          | 0/3 [00:00<?, ?it/s]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_model.py:1126: PerformanceWarning:

DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`

(warning repeated multiple times; truncated)


MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude target target_prediction ... backbone_features_90 backbone_features_91 backbone_features_92 backbone_features_93 backbone_features_94 backbone_features_95 backbone_features_96 backbone_features_97 backbone_features_98 backbone_features_99
20046 1.6812 25.0 4.192201 1.022284 1392.0 3.877437 36.06 -119.01 0.47700 0.921853 ... 0.104538 0.0 0.516786 0.749274 0.0 0.0 0.402392 0.0 0.0 0.000000
3024 2.5313 30.0 5.039384 1.193493 1565.0 2.679795 35.14 -119.46 0.45800 1.190664 ... 0.391450 0.0 0.000000 0.395355 0.0 0.0 0.000000 0.0 0.0 0.039423
15663 3.4801 52.0 3.977155 1.185877 1310.0 1.360332 37.80 -122.44 5.00001 3.117064 ... 0.000000 0.0 0.000000 0.045079 0.0 0.0 0.000000 0.0 0.0 0.204494
20484 5.7376 17.0 6.163636 1.020202 1705.0 3.444444 34.28 -118.72 2.18600 2.724108 ... 0.000000 0.0 0.000000 0.069432 0.0 0.0 0.000000 0.0 0.0 2.369836
9814 3.7250 34.0 5.492991 1.028037 1063.0 2.483645 36.62 -121.93 2.78000 2.599437 ... 0.096121 0.0 0.156296 0.269434 0.0 0.0 0.113502 0.0 0.0 0.205405

5 rows × 125 columns
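Since predict was called with quantiles=[0.25, 0.5, 0.75] and n_samples=100, the quantile estimates are derived by sampling from the predicted mixture for each row. Roughly, the sampling step looks like this (an illustrative sketch; the library's own procedure may differ):

from scipy.special import softmax

def sample_mixture_sketch(pi_logits, mu, sigma, n_samples=100, seed=42):
    # Pick a component per draw according to softmax(pi_logits),
    # then sample from the chosen Gaussian.
    rng = np.random.default_rng(seed)
    p = softmax(np.asarray(pi_logits, dtype=float))
    comps = rng.choice(len(p), size=n_samples, p=p)
    return rng.normal(np.asarray(mu, dtype=float)[comps],
                      np.asarray(sigma, dtype=float)[comps])

row = pred_df.iloc[0]
draws = sample_mixture_sketch(row[['pi_0','pi_1','pi_2','pi_3']].values,
                              row[['mu_0','mu_1','mu_2','mu_3']].values,
                              row[['sigma_0','sigma_1','sigma_2','sigma_3']].values)
np.quantile(draws, [0.25, 0.5, 0.75])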

import scipy.stats as ss

def plot_normal(x_range, mu=0, sigma=1, cdf=False):
    '''
    Returns x and the normal pdf (or cdf, if cdf=True) evaluated
    over the given x range. If mu and sigma are not provided,
    the standard normal distribution is used.
    '''
    x = x_range
    if cdf:
        y = ss.norm.cdf(x, mu, sigma)
    else:
        y = ss.norm.pdf(x, mu, sigma)
    return x,y
import torch
from torch import nn
from torch.distributions import Categorical
def get_pdf(idx):
    row = pred_df.iloc[idx]
    pi = torch.from_numpy(row[['pi_0','pi_1','pi_2','pi_3']].values).unsqueeze(0)
    mu = torch.from_numpy(row[['mu_0','mu_1','mu_2','mu_3']].values).unsqueeze(0)
    sigma = torch.from_numpy(row[['sigma_0','sigma_1','sigma_2','sigma_3']].values).unsqueeze(0)
    softmax_pi = nn.functional.gumbel_softmax(pi, tau=1, dim=-1)
    categorical = Categorical(softmax_pi)
    pis = categorical.sample().unsqueeze(1)
    sigma = sigma.gather(1, pis).item()
    mu = mu.gather(1, pis).item()
    x = np.linspace(row['target_prediction'].item()*0.1, row['target_prediction'].item()*1.9, 5000)
    return plot_normal(x, mu=mu, sigma=sigma)
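Because get_pdf samples a single component via the Gumbel-softmax, repeated calls for the same row can return different curves. A deterministic alternative is to evaluate the full weighted mixture density (an illustrative helper, not part of the library):

from scipy.special import softmax

def get_mixture_pdf(idx):
    # Evaluate the full mixture density instead of sampling one component
    row = pred_df.iloc[idx]
    pi = softmax(row[['pi_0','pi_1','pi_2','pi_3']].values.astype(float))
    mu = row[['mu_0','mu_1','mu_2','mu_3']].values.astype(float)
    sigma = row[['sigma_0','sigma_1','sigma_2','sigma_3']].values.astype(float)
    x = np.linspace(row['target_prediction']*0.1, row['target_prediction']*1.9, 5000)
    y = sum(p * ss.norm.pdf(x, m, s) for p, m, s in zip(pi, mu, sigma))
    return x, y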
# idxs = pred_df[mask].sample(5).index

idxs = [2, 23, 564, 365]
traces = []
for idx in idxs:
    x,y = get_pdf(idx)
    trace = go.Scatter(
            name=f'House_{idx}',
            x=x,
            y=y,
            mode='lines',
    #         line=dict(color='rgba(246, 76, 114, 1)'),
        )
    traces.append(trace)

fig = go.Figure(traces)
fig.update_layout(
    yaxis_title='P(MedHouseVal)',
    xaxis_title='MedHouseVal',
# yaxis_range=[-0.2,1],
    title='PDFs of different Houses',
    hovermode="x"
)
fig.write_image("imgs/prob_reg_pdfs_4.png")