[1]:
import math
import pathlib
from datetime import datetime, timedelta

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    mean_squared_log_error,
    median_absolute_error,
    r2_score,
)
from sklearn.model_selection import train_test_split

import mlflow

使用 MLflow 进行日志可视化

在本指南的这一部分,我们强调了 使用 MLflow 记录可视化的重要性。在训练好的模型旁边保留可视化内容可以增强模型的可解释性、审计和来源追踪,确保一个强大且透明的机器学习生命周期。

我们在做什么?

  • 存储视觉工件: 我们正在 MLflow 中记录各种图表作为视觉工件,确保它们始终可访问并与相应的模型和运行数据保持一致。

  • 增强模型可解释性: 这些可视化工具有助于理解和解释模型行为,从而提高模型的透明度和责任性。

它如何应用于MLflow?

  • 集成可视化日志记录: MLflow 无缝集成日志记录和访问可视化工件的设施,增强了处理可视化上下文和见解的便捷性和效率。

  • 便捷访问: 记录的图形可以在 MLflow UI 的 Runs 视图窗格中显示,确保分析和审查的快速便捷访问。

注意

虽然 MLflow 提供了记录可视化的简单性和便利性,但确保可视化工件与相应模型数据的一致性和相关性至关重要,以保持模型信息的完整性和全面性。

为什么一致的日志记录很重要?

  • 审计和溯源: 对可视化进行一致且全面的记录对于审计目的至关重要,确保每个模型都伴随有相关的视觉洞察,以便进行彻底的分析和审查。

  • 增强模型理解: 适当的视觉上下文增强了模型行为的理解,有助于有效的模型评估和验证。

总之,MLflow 在可视化日志记录方面的能力在确保全面、透明和高效的机器学习生命周期中发挥着不可估量的作用,增强了模型的可解释性、审计和溯源。

生成合成苹果销售数据

在下一节中,我们将深入探讨使用 generate_apple_sales_data_with_promo_adjustment 函数 生成苹果销售需求预测的合成数据。该函数模拟了与苹果销售相关的各种特征,为探索和建模提供了一个丰富的数据集。

我们在做什么?

  • 模拟真实数据: 生成一个包含日期、平均温度、降雨量、周末标志等特征的数据集,模拟苹果销售的现实场景。

  • 结合各种效果: 该函数结合了促销调整、季节性和竞争对手定价等效果,有助于实现‘需求’目标变量。

它如何应用于数据生成?

  • 综合数据集: 合成数据集提供了一套全面的功能和交互,非常适合用于探索需求预测的各个方面和维度。

  • 自由与灵活性: 合成性质允许进行无约束的探索和分析,不受现实世界数据敏感性和约束的影响。

注意

虽然合成数据在探索和学习方面提供了许多优势,但必须承认它在捕捉现实世界的复杂性和细微差别方面的局限性。

为什么承认局限性很重要?

  • 现实世界的复杂性: 合成数据可能无法捕捉到现实世界数据中所有复杂的模式和异常,这可能导致模型和见解过于简化。

  • 可迁移性到真实世界场景: 确保从合成数据中获得的见解和模型能够迁移到真实世界场景,需要仔细考虑和验证。

总之,generate_apple_sales_data_with_promo_adjustment 函数提供了一个强大的工具,用于生成全面的苹果销售需求预测综合数据集,促进了广泛的探索和分析,同时承认了综合数据的局限性。

[2]:
def generate_apple_sales_data_with_promo_adjustment(
    base_demand: int = 1000,
    n_rows: int = 5000,
    competitor_price_effect: float = -50.0,
):
    """
    Generates a synthetic dataset for predicting apple sales demand with multiple
    influencing factors.

    This function creates a pandas DataFrame with features relevant to apple sales.
    The features include date, average_temperature, rainfall, weekend flag, holiday flag,
    promotional flag, price_per_kg, competitor's price, marketing intensity, stock availability,
    and the previous day's demand. The target variable, 'demand', is generated based on a
    combination of these features with some added noise.

    Args:
        base_demand (int, optional): Base demand for apples. Defaults to 1000.
        n_rows (int, optional): Number of rows (days) of data to generate. Defaults to 5000.
        competitor_price_effect (float, optional): Effect of competitor's price being lower
                                                   on our sales. Defaults to -50.

    Returns:
        pd.DataFrame: DataFrame with features and target variable for apple sales prediction.

    Example:
        >>> df = generate_apple_sales_data_with_promo_adjustment(base_demand=1200, n_rows=6000)
        >>> df.head()
    """

    # Set seed for reproducibility
    np.random.seed(9999)

    # Create date range
    dates = [datetime.now() - timedelta(days=i) for i in range(n_rows)]
    dates.reverse()

    # Generate features
    df = pd.DataFrame(
        {
            "date": dates,
            "average_temperature": np.random.uniform(10, 35, n_rows),
            "rainfall": np.random.exponential(5, n_rows),
            "weekend": [(date.weekday() >= 5) * 1 for date in dates],
            "holiday": np.random.choice([0, 1], n_rows, p=[0.97, 0.03]),
            "price_per_kg": np.random.uniform(0.5, 3, n_rows),
            "month": [date.month for date in dates],
        }
    )

    # Introduce inflation over time (years)
    df["inflation_multiplier"] = 1 + (df["date"].dt.year - df["date"].dt.year.min()) * 0.03

    # Incorporate seasonality due to apple harvests
    df["harvest_effect"] = np.sin(2 * np.pi * (df["month"] - 3) / 12) + np.sin(
        2 * np.pi * (df["month"] - 9) / 12
    )

    # Modify the price_per_kg based on harvest effect
    df["price_per_kg"] = df["price_per_kg"] - df["harvest_effect"] * 0.5

    # Adjust promo periods to coincide with periods lagging peak harvest by 1 month
    peak_months = [4, 10]  # months following the peak availability
    df["promo"] = np.where(
        df["month"].isin(peak_months),
        1,
        np.random.choice([0, 1], n_rows, p=[0.85, 0.15]),
    )

    # Generate target variable based on features
    base_price_effect = -df["price_per_kg"] * 50
    seasonality_effect = df["harvest_effect"] * 50
    promo_effect = df["promo"] * 200

    df["demand"] = (
        base_demand
        + base_price_effect
        + seasonality_effect
        + promo_effect
        + df["weekend"] * 300
        + np.random.normal(0, 50, n_rows)
    ) * df["inflation_multiplier"]  # adding random noise

    # Add previous day's demand
    df["previous_days_demand"] = df["demand"].shift(1)
    df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row

    # Introduce competitor pricing
    df["competitor_price_per_kg"] = np.random.uniform(0.5, 3, n_rows)
    df["competitor_price_effect"] = (
        df["competitor_price_per_kg"] < df["price_per_kg"]
    ) * competitor_price_effect

    # Stock availability based on past sales price (3 days lag with logarithmic decay)
    log_decay = -np.log(df["price_per_kg"].shift(3) + 1) + 2
    df["stock_available"] = np.clip(log_decay, 0.7, 1)

    # Marketing intensity based on stock availability
    # Identify where stock is above threshold
    high_stock_indices = df[df["stock_available"] > 0.95].index

    # For each high stock day, increase marketing intensity for the next week
    for idx in high_stock_indices:
        df.loc[idx : min(idx + 7, n_rows - 1), "marketing_intensity"] = np.random.uniform(0.7, 1)

    # If the marketing_intensity column already has values, this will preserve them;
    #  if not, it sets default values
    fill_values = pd.Series(np.random.uniform(0, 0.5, n_rows), index=df.index)
    df["marketing_intensity"].fillna(fill_values, inplace=True)

    # Adjust demand with new factors
    df["demand"] = df["demand"] + df["competitor_price_effect"] + df["marketing_intensity"]

    # Drop temporary columns
    df.drop(
        columns=[
            "inflation_multiplier",
            "harvest_effect",
            "month",
            "competitor_price_effect",
            "stock_available",
        ],
        inplace=True,
    )

    return df

生成苹果销售数据

在这个单元格中,我们调用 generate_apple_sales_data_with_promo_adjustment 函数来生成苹果销售数据集。

使用的参数:

  • base_demand: 设置为1000,代表苹果的基本需求。

  • n_rows: 设置为10,000,确定生成的数据集中行数或数据点的数量。

  • competitor_price_effect: 设置为 -25.0,表示当竞争对手的价格较低时对我们销售的影响。

通过运行此单元格,我们获得了一个数据集 my_data,它包含了上述配置的合成苹果销售数据。此数据集将在本笔记本的后续步骤中用于进一步的探索和分析。

你可以在生成单元格之后看到单元格中的数据。

[3]:
my_data = generate_apple_sales_data_with_promo_adjustment(
    base_demand=1000, n_rows=10_000, competitor_price_effect=-25.0
)
[4]:
my_data
[4]:
date average_temperature rainfall weekend holiday price_per_kg promo demand previous_days_demand competitor_price_per_kg marketing_intensity
0 1996-05-11 13:10:40.689999 30.584727 1.831006 1 0 1.578387 1 1301.647352 1326.324266 0.755725 0.323086
1 1996-05-12 13:10:40.689999 15.465069 0.761303 1 0 1.965125 0 1143.972638 1326.324266 0.913934 0.030371
2 1996-05-13 13:10:40.689998 10.786525 1.427338 0 0 1.497623 0 890.319248 1168.942267 2.879262 0.354226
3 1996-05-14 13:10:40.689997 23.648154 3.737435 0 0 1.952936 0 811.206168 889.965021 0.826015 0.953000
4 1996-05-15 13:10:40.689997 13.861391 5.598549 0 0 2.059993 0 822.279469 835.253168 1.130145 0.953000
... ... ... ... ... ... ... ... ... ... ... ...
9995 2023-09-22 13:10:40.682895 23.358868 7.061220 0 0 1.556829 1 1981.195884 2089.644454 0.560507 0.889971
9996 2023-09-23 13:10:40.682895 14.859048 0.868655 1 0 1.632918 0 2180.698138 2005.305913 2.460766 0.884467
9997 2023-09-24 13:10:40.682894 17.941035 13.739986 1 0 0.827723 1 2675.093671 2179.813671 1.321922 0.884467
9998 2023-09-25 13:10:40.682893 14.533862 1.610512 0 0 0.589172 0 1703.287285 2674.209204 2.604095 0.812706
9999 2023-09-26 13:10:40.682889 13.048549 5.287508 0 0 1.794122 1 1971.029266 1702.474579 1.261635 0.750458

10000 rows × 11 columns

需求的时间序列可视化

在本节中,我们正在创建一个时间序列图,以可视化需求数据及其滚动平均值。

为什么这很重要?

可视化时间序列数据对于识别模式、理解变异性以及做出更明智的决策至关重要。通过绘制滚动平均值,我们可以平滑短期波动并突出长期趋势或周期。这种视觉辅助对于理解数据和做出更准确、更明智的预测和决策是必不可少的。

代码结构:

  • 输入验证:代码首先确保数据是 pandas DataFrame。

  • 日期转换:它将‘日期’列转换为日期时间格式,以便进行准确的绘图。

  • 滚动平均计算: 它计算‘需求’的滚动平均值,使用指定的窗口大小 (window_size),默认为7天。

  • 绘图:它在同一图上绘制原始需求数据和计算的滚动平均值以进行比较。原始需求数据以低透明度绘制,以显得“幽灵般”,确保滚动平均值突出。

  • 标签和图例: 为了清晰起见,添加了适当的标签和图例。

为什么要返回一个图形?

我们返回图形对象(fig)而不是直接渲染它,这样每次模型训练事件的迭代都可以将图形作为记录的工件消耗到MLflow中。这种方法允许我们以训练时使用的数据的精确状态持久化数据可视化的状态。MLflow可以存储这个图形对象,使得在MLflow UI中轻松检索和渲染成为可能,确保可视化始终可访问并与相关的模型和数据信息配对。

[5]:
def plot_time_series_demand(data, window_size=7, style="seaborn", plot_size=(16, 12)):
    if not isinstance(data, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame.")

    df = data.copy()

    df["date"] = pd.to_datetime(df["date"])

    # Calculate the rolling average
    df["rolling_avg"] = df["demand"].rolling(window=window_size).mean()

    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        # Plot the original time series data with low alpha (transparency)
        ax.plot(df["date"], df["demand"], "b-o", label="Original Demand", alpha=0.15)
        # Plot the rolling average
        ax.plot(
            df["date"],
            df["rolling_avg"],
            "r",
            label=f"{window_size}-Day Rolling Average",
        )

        # Set labels and title
        ax.set_title(
            f"Time Series Plot of Demand with {window_size} day Rolling Average",
            fontsize=14,
        )
        ax.set_xlabel("Date", fontsize=12)
        ax.set_ylabel("Demand", fontsize=12)

        # Add legend to explain the lines
        ax.legend()
        plt.tight_layout()

    plt.close(fig)
    return fig

使用箱线图可视化周末与工作日需求

在本节中,我们利用箱线图来可视化周末与工作日需求分布的差异。这种可视化有助于理解基于一周中不同日期的需求变异性和集中趋势。

为什么这很重要?

理解周末和工作日的需求差异对于做出有关库存、人员配置和其他运营方面的明智决策至关重要。它有助于识别需求高峰期,从而实现更好的资源分配和计划。

代码结构:

  • 箱线图:代码使用 Seaborn 创建了一个箱线图,展示了周末(1)和工作日(0)的需求分布。箱线图提供了中位数、四分位数以及两类需求数据中可能的异常值的洞察。

  • 添加单个数据点:为了提供更多上下文,单个数据点作为条形图叠加在箱形图上。它们为了更好地可视化而进行了抖动处理,并根据日期类型进行了颜色编码。

  • 样式: 图表的样式设计清晰,不必要的图例被移除以增强可读性。

为什么要返回一个图形?

与时间序列图一样,此函数也返回图形对象(fig),而不是直接显示它。

[6]:
def plot_box_weekend(df, style="seaborn", plot_size=(10, 8)):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        sns.boxplot(data=df, x="weekend", y="demand", ax=ax, color="lightgray")
        sns.stripplot(
            data=df,
            x="weekend",
            y="demand",
            ax=ax,
            hue="weekend",
            palette={0: "blue", 1: "green"},
            alpha=0.15,
            jitter=0.3,
            size=5,
        )

        ax.set_title("Box Plot of Demand on Weekends vs. Weekdays", fontsize=14)
        ax.set_xlabel("Weekend (0: No, 1: Yes)", fontsize=12)
        ax.set_ylabel("Demand", fontsize=12)
        for i in ax.get_xticklabels() + ax.get_yticklabels():
            i.set_fontsize(10)
        ax.legend_.remove()
        plt.tight_layout()
    plt.close(fig)
    return fig

探究需求与每公斤价格之间的关系

在这个可视化中,我们正在创建一个散点图来研究 demandprice_per_kg 之间的关系。理解这种关系对于定价策略和需求预测至关重要。

为什么这很重要?

  • 定价策略洞察: 此可视化帮助揭示需求如何随每公斤价格变化,为设定价格以优化销售和收入提供有价值的见解。

  • 理解需求弹性: 它有助于理解价格相关的需求弹性,帮助在促销和折扣方面做出明智且数据驱动的决策。

代码结构:

  • 散点图: 代码生成一个散点图,其中每个点的位置由 price_per_kgdemand 决定,颜色表示该天是周末还是工作日。这种颜色编码有助于快速识别周末或工作日特有的模式。

  • 透明度和抖动: 点以透明度 (alpha=0.15) 绘制,以处理重叠绘图,允许可视化点的密度。

  • 回归线: 对于每个子组(周末和工作日),分别拟合并绘制一条回归线在同一坐标轴上。这些线条提供了关于每个组的需求与每公斤价格趋势的清晰视觉指示。

[7]:
def plot_scatter_demand_price(df, style="seaborn", plot_size=(10, 8)):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        # Scatter plot with jitter, transparency, and color-coded based on weekend
        sns.scatterplot(
            data=df,
            x="price_per_kg",
            y="demand",
            hue="weekend",
            palette={0: "blue", 1: "green"},
            alpha=0.15,
            ax=ax,
        )
        # Fit a simple regression line for each subgroup
        sns.regplot(
            data=df[df["weekend"] == 0],
            x="price_per_kg",
            y="demand",
            scatter=False,
            color="blue",
            ax=ax,
        )
        sns.regplot(
            data=df[df["weekend"] == 1],
            x="price_per_kg",
            y="demand",
            scatter=False,
            color="green",
            ax=ax,
        )

        ax.set_title("Scatter Plot of Demand vs Price per kg with Regression Line", fontsize=14)
        ax.set_xlabel("Price per kg", fontsize=12)
        ax.set_ylabel("Demand", fontsize=12)
        for i in ax.get_xticklabels() + ax.get_yticklabels():
            i.set_fontsize(10)
        plt.tight_layout()
    plt.close(fig)
    return fig

可视化需求密度:工作日 vs. 周末

这种可视化让我们能够分别观察工作日和周末的 需求 分布。

为什么这很重要?

  • 需求分布洞察: 了解工作日与周末的需求分布情况可以为库存管理和人员配置提供信息。

  • 通知业务战略: 这一洞察对于制定数据驱动的决策至关重要,涉及促销、折扣和其他可能在特定日期更有效的策略。

代码结构:

  • 密度图: 该代码生成了一个 需求 的密度图,按工作日和周末分开。

  • 颜色编码的组别: 这两个组别(工作日和周末)分别用蓝色和绿色进行颜色编码,便于区分。

  • 透明度和填充: 密度曲线下的区域用浅色透明颜色 (alpha=0.15) 填充,以便于可视化同时避免视觉混乱。

什么是视觉元素?

  • 两条密度曲线: 该图包含两条密度曲线,一条代表工作日,另一条代表周末。这些曲线提供了每个群体需求分布的清晰视觉表示。

  • 图例: 添加图例以帮助识别哪条曲线对应哪个组(工作日或周末)。

[8]:
def plot_density_weekday_weekend(df, style="seaborn", plot_size=(10, 8)):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)

        # Plot density for weekdays
        sns.kdeplot(
            df[df["weekend"] == 0]["demand"],
            color="blue",
            label="Weekday",
            ax=ax,
            fill=True,
            alpha=0.15,
        )

        # Plot density for weekends
        sns.kdeplot(
            df[df["weekend"] == 1]["demand"],
            color="green",
            label="Weekend",
            ax=ax,
            fill=True,
            alpha=0.15,
        )

        ax.set_title("Density Plot of Demand by Weekday/Weekend", fontsize=14)
        ax.set_xlabel("Demand", fontsize=12)
        ax.legend(fontsize=12)
        for i in ax.get_xticklabels() + ax.get_yticklabels():
            i.set_fontsize(10)

        plt.tight_layout()
    plt.close(fig)
    return fig

模型系数的可视化

在本节中,我们利用条形图来可视化训练模型中特征的系数。

为什么这很重要?

理解系数的量级和方向对于解释模型至关重要。它有助于识别对预测影响最大的特征。这一洞察对于特征选择、工程设计以及最终提高模型性能至关重要。

代码结构:

  • 上下文设置:代码首先将绘图样式设置为‘seaborn’,以增强美观性。

  • 图初始化: 它创建一个用于绘图的图形和轴。

  • 条形图:它使用水平条形图(barh)来可视化每个特征的系数。y轴表示特征名称,x轴表示系数值。这种可视化方式使得比较系数变得容易,从而提供了关于它们相对重要性和对目标变量影响的洞察。

  • 标题和标签:它设置了一个适当的标题(“系数图”)以及 x 轴(“系数值”)和 y 轴(“特征”)的标签,以确保清晰度和可理解性。

通过可视化系数,我们可以更深入地理解模型,使其更容易解释模型的预测,并做出更明智的关于特征重要性和影响的决策。

[9]:
def plot_coefficients(model, feature_names, style="seaborn", plot_size=(10, 8)):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        ax.barh(feature_names, model.coef_)
        ax.set_title("Coefficient Plot", fontsize=14)
        ax.set_xlabel("Coefficient Value", fontsize=12)
        ax.set_ylabel("Features", fontsize=12)
        plt.tight_layout()
    plt.close(fig)
    return fig

残差的可视化

在本节中,我们正在创建一个图表来可视化模型的残差,这些残差是观测值与预测值之间的差异。

为什么这很重要?

残差图是回归分析中的一个基本诊断工具,用于研究预测变量和响应变量之间关系的不可预测性。它有助于识别非线性、异方差性和异常值。该图有助于验证误差呈正态分布且具有恒定方差的假设,这对回归模型预测的可靠性至关重要。

代码结构:

  • 残差计算: 代码首先计算残差,即实际值 (y_test) 和预测值 (y_pred) 之间的差异。

  • 上下文设置:代码将绘图样式设置为‘seaborn’,以获得视觉上吸引人的绘图。

  • 图初始化: 它创建一个用于绘图的图形和轴。

  • 残差图:它利用 Seaborn 的 residplot 来创建残差图,并使用低通(局部加权散点平滑)线来突出残差中的趋势。

  • 零线: 它在零处添加一条虚线,作为观察残差的参考。线上方的残差表示预测不足,而线下方的残差表示预测过度。

  • 标题和标签:它设置了一个适当的标题(“残差图”)以及 x 轴(“预测值”)和 y 轴(“残差”)的标签,以确保清晰度和可理解性。

通过检查残差图,我们可以更好地判断模型的充分性以及是否需要进一步的改进或增加复杂性。

[10]:
def plot_residuals(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
    residuals = y_test - y_pred

    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        sns.residplot(
            x=y_pred,
            y=residuals,
            lowess=True,
            ax=ax,
            line_kws={"color": "red", "lw": 1},
        )

        ax.axhline(y=0, color="black", linestyle="--")
        ax.set_title("Residual Plot", fontsize=14)
        ax.set_xlabel("Predicted values", fontsize=12)
        ax.set_ylabel("Residuals", fontsize=12)

        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontsize(10)

        plt.tight_layout()

    plt.close(fig)
    return fig

预测误差的可视化

在本节中,我们正在创建一个图表来可视化预测误差,展示我们模型中实际值和预测值之间的差异。

为什么这很重要?

理解预测误差对于评估模型性能至关重要。预测误差图提供了误差分布的洞察,并有助于识别趋势、偏差或异常值。这种可视化是模型评估的关键组成部分,有助于识别模型可能需要改进的领域,并确保其对新数据的泛化能力。

代码结构:

  • 上下文设置:代码将绘图样式设置为 ‘seaborn’,以获得干净且吸引人的图表。

  • 图形初始化: 它初始化一个用于绘图的图形和坐标轴。

  • 散点图:代码绘制了预测值与误差(实际值 - 预测值)的关系图。图中的每个点代表一个特定的观测值,其在y轴上的位置表示误差的大小和方向(低于零表示预测不足,高于零表示预测过度)。

  • 零线: 在 y=0 处绘制了一条红色虚线作为参考,有助于轻松识别误差。位于此线上方的点表示欠预测,而位于此线下方的点表示过预测。

  • 标题和标签:它添加了一个标题(“预测误差图”)以及 x 轴(“预测值”)和 y 轴(“误差”)的标签,以提高清晰度和理解。

通过分析预测误差图,从业者可以获得关于模型性能的宝贵见解,有助于进一步优化和增强模型,以实现更好、更可靠的预测。

[11]:
def plot_prediction_error(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        ax.scatter(y_pred, y_test - y_pred)
        ax.axhline(y=0, color="red", linestyle="--")
        ax.set_title("Prediction Error Plot", fontsize=14)
        ax.set_xlabel("Predicted Values", fontsize=12)
        ax.set_ylabel("Errors", fontsize=12)
        plt.tight_layout()
    plt.close(fig)
    return fig

分位数-分位数图 (QQ Plot) 的可视化

在本节中,我们将生成一个 QQ 图来可视化模型预测残差的分布。

为什么这很重要?

QQ 图对于评估模型残差是否遵循正态分布至关重要,这是线性回归模型的基本假设。如果 QQ 图中的点没有紧密跟随直线并显示出某种模式,这表明残差可能不是正态分布的,这可能意味着模型存在问题,如异方差性或非线性。

代码结构:

  • 残差计算:代码首先通过从实际测试值中减去预测值来计算残差。

  • 上下文设置:绘图样式设置为‘seaborn’以增强美观性。

  • 图形初始化: 为绘图初始化一个图形和轴。

  • QQ 图生成:使用 stats.probplot 函数生成 QQ 图。它绘制残差的分位数与正态分布的分位数。

  • 标题添加: 为了清晰起见,图中添加了一个标题(“QQ图”)。

通过仔细分析 QQ 图,我们可以确保模型的残差符合正态性假设。如果不符合,探索其他模型类型或变换可能有助于提高模型的性能和可靠性。

[12]:
def plot_qq(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
    residuals = y_test - y_pred
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)
        stats.probplot(residuals, dist="norm", plot=ax)
        ax.set_title("QQ Plot", fontsize=14)
        plt.tight_layout()
    plt.close(fig)
    return fig

特征相关矩阵

在本节中,我们生成一个 特征相关矩阵 来可视化数据集中不同特征之间的关系。

注意: 与本笔记本中的其他图表不同,我们将图表的本地副本保存到磁盘,以展示任意文件的替代日志记录机制,即 log_artifact() API。在下面的主要模型训练和日志记录部分中,您将看到如何将此图表添加到 MLflow 运行中。

为什么这很重要?

理解不同特征之间的相关性对于以下方面至关重要: - 识别多重共线性,这可能影响模型性能和可解释性。 - 深入了解变量之间的关系,这可以指导特征工程和选择。 - 揭示不同特征之间潜在的因果关系或相互作用,这可以增进领域理解和进一步分析。

代码结构:

  • 相关性计算:代码首先计算所提供DataFrame的相关性矩阵。

  • 遮罩:为相关矩阵的上三角创建一个遮罩,因为矩阵是对称的,我们不需要可视化重复的信息。

  • 热图生成: 生成热图以可视化相关系数。颜色渐变和注释提供了变量之间关系的清晰见解。

  • 标题添加:为图表添加标题以便清晰识别。

通过分析相关矩阵,我们可以做出更明智的特征选择决策,并更好地理解数据集内部的关系。

[13]:
def plot_correlation_matrix_and_save(
    df, style="seaborn", plot_size=(10, 8), path="/tmp/corr_plot.png"
):
    with plt.style.context(style=style):
        fig, ax = plt.subplots(figsize=plot_size)

        # Calculate the correlation matrix
        corr = df.corr()

        # Generate a mask for the upper triangle
        mask = np.triu(np.ones_like(corr, dtype=bool))

        # Draw the heatmap with the mask and correct aspect ratio
        sns.heatmap(
            corr,
            mask=mask,
            cmap="coolwarm",
            vmax=0.3,
            center=0,
            square=True,
            linewidths=0.5,
            annot=True,
            fmt=".2f",
        )

        ax.set_title("Feature Correlation Matrix", fontsize=14)
        plt.tight_layout()

    plt.close(fig)
    # convert to filesystem path spec for os compatibility
    save_path = pathlib.Path(path)
    fig.savefig(save_path)

模型训练与可视化主要执行的详细概述

本节深入探讨了模型训练、预测、误差计算和可视化的全面工作流程。每个步骤的重要性以及特定选择的原因都进行了详细讨论。

结构化执行的好处

以结构化的方式执行模型训练和评估的所有关键步骤是基础性的。它提供了一个框架,确保建模过程的每个方面都被考虑,从而提供一个更可靠和健壮的模型。这种简化的执行有助于避免遗漏的错误或偏见,并确保模型在所有必要的方面得到评估。

日志可视化对 MLflow 的重要性

将日志可视化到 MLflow 提供了几个关键优势:

  • 持久性:与笔记本的短暂状态不同,笔记本中的单元格可以不按顺序运行,从而可能导致潜在的误解,将绘图记录到 MLflow 可以确保可视化与特定运行一起永久存储。这种持久性确保了模型训练和评估的视觉上下文得以保留,消除了混淆,并确保了解释的清晰性。

  • 出处: 通过记录可视化,捕获了模型训练时数据的精确状态和关系。对于很久以前训练的模型,这一做法至关重要。它提供了一个可靠的参考点,以理解模型在训练时的行为和数据特征,确保随着时间的推移,见解和解释仍然有效和可靠。

  • 可访问性:将可视化存储在 MLflow 中使得所有团队成员或相关利益方都能轻松访问。这种可视化的集中存储增强了协作,使不同团队成员能够轻松查看、分析和解释可视化内容,从而促进更明智和集体的决策。

代码的详细结构:

  1. 设置 MLflow:

    • MLflow 的跟踪 URI 已定义。

    • 一个名为“可视化演示”的实验已经设置,所有运行和日志将存储在此实验下。

  2. 数据准备

    • Xy 分别被定义为特征和目标变量。

    • 数据集被分为训练集和测试集,以确保模型的性能在未见过的数据上进行评估。

  3. 初始图生成:

    • 生成了包括时间序列、箱线图、散点图和密度图在内的初始图表。

    • 这些图表提供了对数据及其特征的初步洞察。

  4. 模型定义与训练

    • 一个岭回归模型以 alpha 值为 1.0 定义。

    • 模型在训练数据上进行训练,学习数据中的关系和模式。

  5. 预测与误差计算:

    • 训练好的模型用于对测试数据进行预测。

    • 计算了包括MSE、RMSE、MAE、R2、MSLE和MedAE在内的各种误差指标,以评估模型的性能。

  6. 附加图表生成

    • 生成了包括残差图、系数图、预测误差图和QQ图在内的其他图表。

    • 这些图表进一步揭示了模型的性能、残差行为以及误差的分布。

  7. 记录到 MLflow:

    • 训练好的模型、计算的指标、定义的参数(alpha)以及所有生成的图表都会记录到 MLflow 中。

    • 此日志记录确保与模型相关的所有信息和可视化内容都存储在一个集中且可访问的位置。

结论:

通过执行这个全面且结构化的代码,我们确保了模型训练、评估和解释的每个方面都得到了覆盖。将所有相关信息和可视化记录到 MLflow 的做法进一步增强了模型的可靠性、可访问性和可解释性,有助于更明智和可靠的模型部署和使用。

[14]:
mlflow.set_tracking_uri("http://127.0.0.1:8080")

mlflow.set_experiment("Visualizations Demo")

X = my_data.drop(columns=["demand", "date"])
y = my_data["demand"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

fig1 = plot_time_series_demand(my_data, window_size=28)
fig2 = plot_box_weekend(my_data)
fig3 = plot_scatter_demand_price(my_data)
fig4 = plot_density_weekday_weekend(my_data)

# Execute the correlation plot, saving the plot to a local temporary directory
plot_correlation_matrix_and_save(my_data)

# Define our Ridge model
model = Ridge(alpha=1.0)

# Train the model
model.fit(X_train, y_train)

# Make predictions
y_pred = model.predict(X_test)

# Calculate error metrics
mse = mean_squared_error(y_test, y_pred)
rmse = math.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
msle = mean_squared_log_error(y_test, y_pred)
medae = median_absolute_error(y_test, y_pred)

# Generate prediction-dependent plots
fig5 = plot_residuals(y_test, y_pred)
fig6 = plot_coefficients(model, X_test.columns)
fig7 = plot_prediction_error(y_test, y_pred)
fig8 = plot_qq(y_test, y_pred)

# Start an MLflow run for logging metrics, parameters, the model, and our figures
with mlflow.start_run() as run:
    # Log the model
    mlflow.sklearn.log_model(sk_model=model, input_example=X_test, artifact_path="model")

    # Log the metrics
    mlflow.log_metrics(
        {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2, "msle": msle, "medae": medae}
    )

    # Log the hyperparameter
    mlflow.log_param("alpha", 1.0)

    # Log plots
    mlflow.log_figure(fig1, "time_series_demand.png")
    mlflow.log_figure(fig2, "box_weekend.png")
    mlflow.log_figure(fig3, "scatter_demand_price.png")
    mlflow.log_figure(fig4, "density_weekday_weekend.png")
    mlflow.log_figure(fig5, "residuals_plot.png")
    mlflow.log_figure(fig6, "coefficients_plot.png")
    mlflow.log_figure(fig7, "prediction_errors.png")
    mlflow.log_figure(fig8, "qq_plot.png")

    # Log the saved correlation matrix plot by referring to the local file system location
    mlflow.log_artifact("/tmp/corr_plot.png")
2023/09/26 13:10:41 INFO mlflow.tracking.fluent: Experiment with name 'Visualizations Demo' does not exist. Creating a new experiment.
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/mlflow/models/signature.py:333: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values <https://www.mlflow.org/docs/latest/models.html#handling-integers-with-missing-values>`_ for more details.
  input_schema = _infer_schema(input_ex)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")