pymc.gp.HSGP.prior_linearized

HSGP.prior_linearized(Xs)

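Linearized version of the HSGP. Returns the Laplace eigenfunctions and the square root of the power spectral density needed to create the GP.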

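This function lets the user bypass the GP interface and work directly with the basis and coefficients. In this form, predictions can be made with pm.set_data, just as with a linear model. It also enables computational speed-ups in models with multiple GPs, since they may share the same basis. The return values are the Laplace eigenfunctions, phi, and the square root of the power spectral density.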
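For a one-dimensional input, the linearized form can be written out as follows. This is a sketch assuming the standard Hilbert-space (reduced-rank) construction of Solin and Särkkä on the interval [-L, L] with Dirichlet boundary conditions; S denotes the spectral density of the covariance function and the beta_j are the coefficients. The last line is the matrix form that corresponds to phi @ (beta * sqrt_psd) in the example.

```latex
\lambda_j = \left(\frac{\pi j}{2L}\right)^2, \qquad
\phi_j(x) = \frac{1}{\sqrt{L}} \sin\!\left(\sqrt{\lambda_j}\,(x + L)\right), \qquad j = 1, \dots, m

f(x) \approx \sum_{j=1}^{m} \phi_j(x)\, \sqrt{S\!\left(\sqrt{\lambda_j}\right)}\, \beta_j,
\qquad \beta_j \sim \mathcal{N}(0, 1)

\mathbf{f} \approx \Phi \left(\boldsymbol{\beta} \odot \mathbf{s}\right),
\qquad s_j = \sqrt{S\!\left(\sqrt{\lambda_j}\right)}
```

Only Phi depends on the inputs; the covariance hyperparameters enter solely through s, which is why several GPs evaluated on the same inputs can reuse one Phi.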

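Two conditions must hold to get correct results when using prior_linearized together with pm.set_data and pm.MutableData. First, L, not c, must be specified when the GP is constructed; otherwise a RuntimeError is raised. Second, Xs needs to be zero-centered, so its mean must be subtracted. An example is given below.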
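A plausible reason for the first condition is that a boundary given through c is relative to the data, so it would be recomputed whenever the data are mutated. The following NumPy sketch is separate from the PyMC example and illustrates that effect under an assumption: the hypothetical helper `boundary_from_c` uses the rule L = c * max(|Xs|) from the HSGP literature, which may not match PyMC's internals exactly. It also centers new inputs with the training mean, as the second condition requires.

```python
import numpy as np

def boundary_from_c(xs, c):
    """Hypothetical helper: boundary L derived from a proportion c (assumed rule L = c * max|xs|)."""
    return c * np.max(np.abs(xs))

def laplace_eigenfunctions_1d(xs, m, L):
    """Dirichlet Laplace eigenfunctions on [-L, L], returned with shape (len(xs), m)."""
    sqrt_eigvals = np.pi * np.arange(1, m + 1) / (2.0 * L)
    return np.sin(sqrt_eigvals * (xs[:, None] + L)) / np.sqrt(L)

X_train = np.linspace(0, 10, 100)
X_new = np.linspace(-5, 15, 100)

# Center both sets with the *training* mean, never with the new data's own mean.
X_mean = X_train.mean()
Xs_train = X_train - X_mean
Xs_new = X_new - X_mean

# With c, the boundary is recomputed from whatever data is passed in...
L_train = boundary_from_c(Xs_train, c=1.5)
L_new = boundary_from_c(Xs_new, c=1.5)

# ...so the same input point gets a different basis row, and coefficients
# fitted against the training basis would no longer line up with it.
x0 = np.array([2.0])
row_train = laplace_eigenfunctions_1d(x0, m=20, L=L_train)
row_new = laplace_eigenfunctions_1d(x0, m=20, L=L_new)

print(L_train, L_new, np.allclose(row_train, row_new))
```

With a fixed L, the basis row for a given centered input is the same no matter which data set is currently loaded, which is what the linear-model style of prediction relies on.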

Parameters:
Xs: array-like

Function input values. Assumes they have been mean-subtracted, i.e. centered at zero.

Returns:
phi: array_like

Either a NumPy or a PyTensor 2D array of the fixed basis vectors. There are n rows, one per row of Xs, and prod(m) columns, one for each basis vector.

sqrt_psd: array_like

Either a NumPy or a PyTensor 1D array of the square roots of the power spectral density, with one entry per basis vector.
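To make the two return values concrete, here is a plain NumPy sketch of what they contain for a 1-D ExpQuad kernel. It is not PyMC code: the helpers `laplace_eigenfunctions_1d` and `expquad_psd_1d` are hypothetical, and the sketch assumes the standard Dirichlet eigenfunctions on [-L, L] and the usual ExpQuad spectral density eta**2 * sqrt(2*pi) * ell * exp(-(ell*omega)**2 / 2). As a check, it verifies that the implied covariance phi @ diag(sqrt_psd**2) @ phi.T reproduces the exact kernel on centered inputs.

```python
import numpy as np

def laplace_eigenfunctions_1d(xs, m, L):
    """Dirichlet Laplace eigenfunctions on [-L, L]; returns (phi, sqrt_eigvals)."""
    sqrt_eigvals = np.pi * np.arange(1, m + 1) / (2.0 * L)       # sqrt(lambda_j), j = 1..m
    phi = np.sin(sqrt_eigvals * (xs[:, None] + L)) / np.sqrt(L)  # shape (n, m)
    return phi, sqrt_eigvals

def expquad_psd_1d(omega, eta, ell):
    """Spectral density of eta**2 * exp(-r**2 / (2 * ell**2)) for 1-D inputs."""
    return eta**2 * np.sqrt(2.0 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)

m, L, eta, ell = 200, 10.0, 1.0, 1.0
X = np.linspace(0, 10, 100)
Xs = X - X.mean()  # zero-centered inputs, as prior_linearized expects

phi, sqrt_eigvals = laplace_eigenfunctions_1d(Xs, m, L)
sqrt_psd = np.sqrt(expquad_psd_1d(sqrt_eigvals, eta, ell))  # one entry per basis vector

# Covariance implied by the basis expansion vs. the exact ExpQuad kernel.
K_approx = (phi * sqrt_psd**2) @ phi.T
K_exact = eta**2 * np.exp(-0.5 * (Xs[:, None] - Xs[None, :]) ** 2 / ell**2)
max_err = np.max(np.abs(K_approx - K_exact))
print(phi.shape, sqrt_psd.shape, max_err)
```

Here m = 200 is far more than this lengthscale needs, and the centered inputs stay well away from ±L, so the reconstruction error is at the level of floating-point round-off.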

Examples

import numpy as np
import pymc as pm

# A one-dimensional column vector of inputs.
X = np.linspace(0, 10, 100)[:, None]

# Synthetic observations, for illustration only.
rng = np.random.default_rng(42)
y = np.sin(X[:, 0]) + rng.normal(0, 0.3, size=X.shape[0])

with pm.Model() as model:
    eta = pm.Exponential("eta", lam=1.0)
    ell = pm.InverseGamma("ell", mu=5.0, sigma=5.0)
    cov_func = eta**2 * pm.gp.cov.ExpQuad(1, ls=ell)

    # m = [200] means 200 basis vectors for the first dimension
    # L = [10] means the approximation is valid for Xs in [-10, 10]
    m = [200]
    gp = pm.gp.HSGP(m=m, L=[10], cov_func=cov_func)

    # Order is important. First calculate the mean, then make X a shared variable,
    # then subtract the mean. When X is mutated later, the same (training) mean
    # will be subtracted.
    X_mean = np.mean(X, axis=0)
    X = pm.MutableData("X", X)
    Xs = X - X_mean

    # Pass the mean-subtracted Xs in to the GP
    phi, sqrt_psd = gp.prior_linearized(Xs=Xs)

    # Specify a standard normal prior on the coefficients. There is one
    # coefficient per basis vector, i.e. prod(m) of them.
    beta = pm.Normal("beta", size=int(np.prod(m)))

    # The (non-centered) GP approximation is given by
    f = pm.Deterministic("f", phi @ (beta * sqrt_psd))

    # Illustrative Gaussian likelihood; substitute the model's own likelihood.
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    pm.Normal("y_obs", mu=f, sigma=sigma, observed=y)

    idata = pm.sample()

# Then it works just like a linear regression to predict on new data.
# First mutate the data X. New points must keep X - X_mean inside [-L, L],
# i.e. x in [-5, 15] here.
x_new = np.linspace(-5, 15, 100)
with model:
    model.set_data("X", x_new[:, None])

# and then make predictions for the GP using posterior predictive sampling.
with model:
    ppc = pm.sample_posterior_predictive(idata, var_names=["f"])
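Mechanically, the prediction step above is a matrix product with a rebuilt basis. The following NumPy sketch (not PyMC code) mimics it, with random draws standing in for posterior samples of beta; the helpers are hypothetical and assume the same Dirichlet eigenfunctions and ExpQuad spectral density as the standard HSGP construction. It also illustrates the shared-basis point from the description: two GPs with different hyperparameters reuse one phi and differ only in sqrt_psd and beta.

```python
import numpy as np

def laplace_eigenfunctions_1d(xs, m, L):
    """Dirichlet Laplace eigenfunctions on [-L, L]; returns (phi, sqrt_eigvals)."""
    sqrt_eigvals = np.pi * np.arange(1, m + 1) / (2.0 * L)
    phi = np.sin(sqrt_eigvals * (xs[:, None] + L)) / np.sqrt(L)
    return phi, sqrt_eigvals

def expquad_sqrt_psd(sqrt_eigvals, eta, ell):
    """Square root of the 1-D ExpQuad spectral density at the given frequencies."""
    psd = eta**2 * np.sqrt(2.0 * np.pi) * ell * np.exp(-0.5 * (ell * sqrt_eigvals) ** 2)
    return np.sqrt(psd)

rng = np.random.default_rng(0)
m, L = 200, 10.0
X_train = np.linspace(0, 10, 100)
X_mean = X_train.mean()

# One basis for the training inputs, shared by two GPs with different hyperparameters.
phi_train, sqrt_eigvals = laplace_eigenfunctions_1d(X_train - X_mean, m, L)
beta1, beta2 = rng.standard_normal(m), rng.standard_normal(m)  # stand-ins for posterior draws
sqrt_psd1 = expquad_sqrt_psd(sqrt_eigvals, eta=1.0, ell=1.0)
sqrt_psd2 = expquad_sqrt_psd(sqrt_eigvals, eta=0.5, ell=3.0)
f1_train = phi_train @ (beta1 * sqrt_psd1)
f2_train = phi_train @ (beta2 * sqrt_psd2)

# Prediction: only phi is rebuilt at the new inputs (centered with the training mean);
# the coefficients and sqrt_psd are reused unchanged.
x_new = np.linspace(-5, 15, 100)
phi_new, _ = laplace_eigenfunctions_1d(x_new - X_mean, m, L)
f1_new = phi_new @ (beta1 * sqrt_psd1)
```

The endpoints of x_new map to Xs = -L and Xs = L, where every Dirichlet eigenfunction vanishes, so f1_new is (numerically) zero there. Predictions outside [-L, L] are not meaningful, which is why new inputs should stay inside that interval.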