外推技巧与窍门#

处理外推——在插值数据域之外的查询点上评估插值器——在 scipy.interpolate 中的不同例程之间并不完全一致。不同的插值器使用不同的关键字参数集来控制数据域之外的行为:有些使用 extrapolate=True/False/None,有些允许使用 fill_value 关键字。请参阅每个特定插值例程的 API 文档以获取详细信息。

根据具体问题,可用的关键字可能足够也可能不够。需要特别注意非线性插值器的外推。通常,随着与数据域的距离增加,外推结果会变得越来越不合理。这当然是意料之中的:插值器只知道数据域内的数据。

当默认的外推结果不合适时,用户需要自己实现所需的外推模式。

在本教程中,我们考虑了几个实际示例,展示了如何使用可用关键字和手动实现所需的外推模式。这些示例可能适用于也可能不适用于您的问题;它们不一定是最优实践;并且它们有意简化为仅展示主要思想所需的基本要素,希望它们能为您处理特定问题提供灵感。

interp1d : 复制 numpy.interp 的左右填充值#

简而言之:使用 fill_value=(left, right)

numpy.interp 使用常数外推,默认情况下在插值中扩展 y 数组的第一个和最后一个值 interval: np.interp(xnew, x, y) 的输出在 xnew < x[0] 时为 y[0],在 xnew > x[-1] 时为 y[-1]

默认情况下,interp1d 拒绝外推,并在评估超出插值范围的数据点时引发 ValueError。可以通过 bounds_error=False 参数关闭此行为:然后 interp1d 使用 fill_value 设置超出范围的值,默认情况下 fill_valuenan

要模仿 numpy.interp 的行为,可以使用 interp1d 支持将 fill_value 设置为 2 元组的事实。元组的元素分别用于填充 xnew < min(x)x > max(x) 的情况。对于多维 y,这些元素必须与 y 具有相同的形状,或者可以广播到 y

举例说明:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

x = np.linspace(0, 1.5*np.pi, 11)
y = np.column_stack((np.cos(x), np.sin(x)))   # y.shape 为 (11, 2)

func = interp1d(x, y,
                axis=0,  # 沿列插值
                bounds_error=False,
                kind='linear',
                fill_value=(y[0], y[-1]))
xnew = np.linspace(-np.pi, 2.5*np.pi, 51)
ynew = func(xnew)

fix, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.plot(xnew, ynew[:, 0])
ax1.plot(x, y[:, 0], 'o')

ax2.plot(xnew, ynew[:, 1])
ax2.plot(x, y[:, 1], 'o')
plt.tight_layout()
../../_images/extrapolation_examples-1.png

CubicSpline 扩展边界条件#

CubicSpline 需要两个额外的边界条件,这些条件由 bc_type 参数控制。该参数可以列出边缘处的导数的显式值,或使用有用的别名。例如,bc_type="clamped" 将一阶导数设置为零,bc_type="natural" 将二阶导数设置为零(另外两个 已识别的字符串值为“periodic”和“not-a-knot”)

虽然外推由边界条件控制,但这种关系并不十分直观。例如,人们可能会期望对于``bc_type=”natural”``,外推是线性的。这种期望过于强烈:每个边界条件仅在边界处设置导数。外推是从第一个和最后一个多项式片段进行的,对于自然样条来说,这是一个在给定点处二阶导数为零的三次多项式。

另一种理解为什么这种期望过于强烈的方法是考虑只有三个数据点的情况,其中样条有两个多项式片段。为了线性外推,这种期望意味着这两个片段都是线性的。但是,两个线性片段不能在中点处连续匹配二阶导数!(当然,除非所有三个数据点实际上位于同一条直线上)。

为了说明这种行为,我们考虑一个合成数据集,并比较几种边界条件:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline

xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [4.5, 3.6, 1.6, 0.0, -3.3, -3.1, -1.8, -1.7]

notaknot = CubicSpline(xs, ys, bc_type='not-a-knot')
natural = CubicSpline(xs, ys, bc_type='natural')
clamped = CubicSpline(xs, ys, bc_type='clamped')
xnew = np.linspace(min(xs) - 4, max(xs) + 4, 101)

splines = [notaknot, natural, clamped]
titles = ['not-a-knot', 'natural', 'clamped']

fig, axs = plt.subplots(3, 3, figsize=(12, 12))
for i in [0, 1, 2]:
    for j, spline, title in zip(range(3), splines, titles):
        axs[i, j].plot(xs, spline(xs, nu=i),'o')
        axs[i, j].plot(xnew, spline(xnew, nu=i),'-')
        axs[i, j].set_title(f'{title}, deriv={i}')

plt.tight_layout()
plt.show()
../../_images/extrapolation_examples-2.png

很明显,自然样条确实具有零二阶导数,但外推行为并非线性。 边界处的导数,但外推是非线性的。 bc_type="clamped" 表现出类似的行为:一阶导数仅在边界处精确为零。在所有情况下,外推都是通过扩展样条的第一段和最后一段多项式来完成的,无论它们是什么。

强制外推的一种可能方法是扩展插值域,以添加具有所需属性的第一段和最后一段多项式。

这里我们使用 CubicSpline 超类 PPolyextend 方法,添加两个额外的断点,并确保额外的多项式段保持导数的值。然后使用这两个额外的区间进行外推。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline

def add_boundary_knots(spline):
    """
    在左右两侧添加无穷小的节点。

    添加额外的区间以使二阶和三阶导数为零,并保持所选边界条件的一阶导数。样条在原地修改。
    """
    # 确定左侧边缘的斜率
    leftx = spline.x[0]
    lefty = spline(leftx)
    leftslope = spline(leftx, nu=1)

    # 在左侧添加一个新的断点,并使用已知的斜率构造 PPoly 系数。
    leftxnext = np.nextafter(leftx, leftx - 1)
    leftynext = lefty + leftslope*(leftxnext - leftx)
    leftcoeffs = np.array([0, 0, leftslope, leftynext])
    spline.extend(leftcoeffs[..., None], np.r_[leftxnext])

    # 在右侧重复添加额外的节点
    rightx = spline.x[-1]
    righty = spline(rightx)
    rightslope = spline(rightx,nu=1)
    rightxnext = np.nextafter(rightx, rightx + 1)
    rightynext = righty + rightslope * (rightxnext - rightx)

    rightcoeffs = np.array([0, 0, rightslope, rightynext])
    spline.extend(rightcoeffs[..., None], np.r_[rightxnext])

xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [4.5, 3.6, 1.6, 0.0, -3.3, -3.1, -1.8, -1.7]

notaknot = CubicSpline(xs,ys, bc_type='not-a-knot')
# not-a-knot 不需要额外的区间

natural = CubicSpline(xs,ys, bc_type='natural')
# 用线性外推结点扩展自然样条
add_boundary_knots(natural)

clamped = CubicSpline(xs,ys, bc_type='clamped')
# 用常数外推结点扩展钳制样条
add_boundary_knots(clamped)

xnew = np.linspace(min(xs) - 5, max(xs) + 5, 201)

fig, axs = plt.subplots(3, 3,figsize=(12,12))

splines = [notaknot, natural, clamped]
titles = ['not-a-knot', 'natural', 'clamped']

for i in [0, 1, 2]:
    for j, spline, title in zip(range(3), splines, titles):
        axs[i, j].plot(xs, spline(xs, nu=i),'o')
        axs[i, j].plot(xnew, spline(xnew, nu=i),'-')
        axs[i, j].set_title(f'{title}, deriv={i}')

plt.tight_layout()
plt.show()



a = 3
x0 = brentq(f, 1e-16, np.pi/2, args=(a,))   # 这里我们将左边界向右移动一个机器精度,以避免在 x=0 处出现除零错误
                                            # by a machine epsilon to avoid
                                            # a division by zero at x=0
xx = np.linspace(0.2, np.pi/2, 101)
plt.plot(xx, a*xx, '--')
plt.plot(xx, 1/np.tan(xx), '--')
plt.plot(x0, a*x0, 'o', ms=12)
plt.text(0.1, 0.9, fr'$x_0 = {x0:.3f}$',
               transform=plt.gca().transAxes, fontsize=16)
plt.show()

然而,如果我们需要多次求解(例如,由于 tan 函数的周期性,需要找到一系列根),重复调用 scipy.optimize.brentq 会变得非常昂贵。

为了绕过这个困难,我们可以在表格中记录 \(y = ax - 1/\tan{x}\) 并在表格网格上进行插值。实际上,我们将使用 反向 插值:我们插值 \(x\) 相对于 \(y\) 的值。这样,求解原始方程就变成了在零 \(y\) 参数处对插值函数进行简单求值。

为了提高插值的准确性,我们将利用表格函数的导数知识。我们将使用 BPoly.from_derivatives 来构建三次插值(等效地,我们可以使用 CubicHermiteSpline

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import BPoly

def f(x, a):
    return a*x - 1/np.tan(x)

xleft, xright = 0.2, np.pi/2
x = np.linspace(xleft, xright, 11)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for j, a in enumerate([3, 93]):
    y = f(x, a)
    dydx = a + 1./np.sin(x)**2    # d(ax - 1/tan(x)) / dx
    dxdy = 1 / dydx               # dx/dy = 1 / (dy/dx)

    xdx = np.c_[x, dxdy]
    spl = BPoly.from_derivatives(y, xdx)   # 反向插值

    yy = np.linspace(f(xleft, a), f(xright, a), 51)
    ax[j].plot(yy, spl(yy), '--')
    ax[j].plot(y, x, 'o')
    ax[j].set_xlabel(r'$y$')
    ax[j].set_ylabel(r'$x$')
    ax[j].set_title(rf'$a = {a}$')

    ax[j].plot(0, spl(0), 'o', ms=12)
    ax[j].text(0.1, 0.85, fr'$x_0 = {spl(0):.3f}$',
               transform=ax[j].transAxes, fontsize=18)
    ax[j].grid(True)
plt.tight_layout()
plt.show()
../../_images/extrapolation_examples-4.png

注意,对于 \(a=3\)spl(0) 与上面的 brentq 调用一致,而对于 \(a = 93\),差异是显著的。该过程在大 \(a\) 时开始失效的原因是直线 \(y = ax\) 趋向于垂直轴,而原方程的根趋向于 \(x=0\)。由于我们在有限网格上对原函数进行了表格化,spl(0) 对于过大的 \(a\) 值涉及外推。依赖外推容易失去精度,最好避免。

利用已知的渐近性#

观察原方程,我们注意到对于 \(x\to 0\)\(\tan(x) = x + O(x^3)\),原方程变为

\[ax = 1/x \;,\]

因此对于 \(a \gg 1\)\(x_0 \approx 1/\sqrt{a}\)

我们将利用这一点来构建一个类,该类在超出范围的数据上从插值切换到使用这种已知的渐近行为。一个简单的实现可能如下所示

class RootWithAsymptotics:
   def __init__(self, a):

       # 构建插值函数
       xleft, xright = 0.2, np.pi/2
       x = np.linspace(xleft, xright, 11)

       y = f(x, a)
       dydx = a + 1./np.sin(x)**2    # d(ax - 1/tan(x)) / dx
       dxdy = 1 / dydx               # dx/dy = 1 / (dy/dx)

       # 反向插值
       self.spl = BPoly.from_derivatives(y, np.c_[x, dxdy])
       self.a = a

   def root(self):
       out = self.spl(0)
       asympt = 1./np.sqrt(self.a)
       return np.where(spl.x.min() < asympt, out, asympt)

然后

>>> r = RootWithAsymptotics(93)
>>> r.root()
array(0.10369517)

这与外推结果不同,但与 brentq 调用一致。

请注意,此实现有意进行了简化。从 API 的角度来看,您可能希望实现 __call__ 方法,以便 xy 的完整依赖关系可用。从数值角度来看,需要更多的工作来确保插值和渐近之间的切换在渐近区域足够深入,以便在切换点处生成的函数足够平滑。

此外,在此示例中,我们人为地将问题限制为仅考虑 tan 函数的单个周期区间,并且仅处理 \(a > 0\)。对于 \(a\) 的负值,我们需要实现其他渐近性,即 \(x\to \pi\)

然而,基本思想是相同的。

“””CT 插值器 + 最近邻外推。

xyndarray, shape (npoints, ndim)

数据点的坐标

zndarray, shape (npoints)

数据点的值

funccallable

一个可调用对象,模拟 CT 行为, 并在数据范围外进行最近邻外推。

“”” x = xy[:, 0] y = xy[:, 1] f = CT(xy, z)

# 这个内部函数将返回给用户 def new_f(xx, yy):

# 评估 CT 插值器。超出范围的值为 nan。 zz = f(xx, yy) nans = np.isnan(zz)

if nans.any():

# 对于每个 nan 点,找到其最近邻 inds = np.argmin(

(x[:, None] - xx[nans])**2 + (y[:, None] - yy[nans])**2 , axis=0)

# … 并使用其值 zz[nans] = z[inds]

return zz

return new_f

# 现在在小例子上展示原始 CT 插值器和 my_CT 之间的区别:

x = np.array([1, 1, 1, 2, 2, 2, 4, 4, 4]) y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]) z = np.array([0, 7, 8, 3, 4, 7, 1, 3, 4])

xy = np.c_[x, y] lut = CT(xy, z) lut2 = my_CT(xy, z)

X = np.linspace(min(x) - 0.5, max(x) + 0.5, 71) Y = np.linspace(min(y) - 0.5, max(y) + 0.5, 71) X, Y = np.meshgrid(X, Y)

fig = plt.figure() ax = fig.add_subplot(projection=’3d’)

ax.plot_wireframe(X, Y, lut(X, Y), label=’CT’) ax.plot_wireframe(X, Y, lut2(X, Y), color=’m’,

cstride=10, rstride=10, alpha=0.7, label=’CT + n.n.’)

ax.scatter(x, y, z, ‘o’, color=’k’, s=48, label=’data’) ax.legend() plt.tight_layout() ```