编写自定义函数

本指南将描述如何在 SymPy 中创建自定义函数类。自定义用户定义函数使用与 SymPy 自带的 函数 相同的机制,例如常见的 初等函数exp()sin()特殊函数gamma()Si(),以及 组合函数数论函数factorial()primepi()。因此,本指南既适用于希望定义自己自定义函数的使用者,也适用于希望扩展 SymPy 自带函数的 SymPy 开发者。

本指南描述了如何定义复值函数,即映射 \(\mathbb{C}^n\) 子集到 \(\mathbb{C}\) 的函数。接受或返回复数以外类型对象的函数应继承其他类,如 BooleanMatrixExprExprBasic。这里所写的内容中,有些适用于一般的 BasicExpr 子类,但大部分内容仅适用于 Function 子类。

简单情况:完全符号化或完全求值

在深入探讨自定义函数的更高级功能之前,我们应该提到两种常见情况:函数完全符号化的情况和函数完全求值的情况。这两种情况都有比本指南中描述的完整机制简单得多的替代方案。

完全符号化的情况

如果你的函数 f 没有任何你想定义的数学属性,并且永远不应该在任何参数上求值,你可以使用 Function('f') 创建一个未定义的函数。

>>> from sympy import symbols, Function
>>> x = symbols('x')
>>> f = Function('f')
>>> f(x)
f(x)
>>> f(0)
f(0)

这在解决 ODEs 时很有用。

如果你只想创建一个依赖于另一个符号的符号以用于微分,这也是有用的。默认情况下,SymPy 假设所有符号彼此独立:

>>> from sympy.abc import x, y
>>> y.diff(x)
0

要创建一个依赖于另一个符号的符号,可以使用一个显式依赖于该符号的函数。

>>> y = Function('y')
>>> y(x).diff(x)
Derivative(y(x), x)

如果你想让你的函数具有额外的行为,例如,拥有自定义的导数,或在某些参数上进行求值,你应该创建一个自定义的 Function 子类,如 下文所述。然而,未定义的函数确实支持一个额外的功能,即可以使用与符号相同的语法来定义假设。这定义了函数输出的假设,而不是输入(也就是说,它定义了函数的值域,而不是定义域)。

>>> g = Function('g', real=True)
>>> g(x)
g(x)
>>> g(x).is_real
True

要使函数的假设以某种方式依赖于其输入,您应该创建一个自定义的 Function 子类,并按照 如下所述 定义假设处理程序。

完全评估的情况

在另一端是那些无论输入是什么,总是会求值的函数。这些函数永远不会以未求值的符号形式,如 f(x) ,存在。

在这种情况下,你应该使用带有 def 关键字的普通 Python 函数:

>>> def f(x):
...     if x == 0:
...         return 0
...     else:
...         return x + 1
>>> f(0)
0
>>> f(1)
2
>>> f(x)
x + 1

如果你发现自己在一个 Function 子类上定义了一个 eval() 方法,并且你总是返回一个值而从不返回 None,你应该考虑改用普通的 Python 函数,因为在这种情况下使用符号 Function 子类没有任何好处(参见下面的 eval() 的最佳实践 部分)

需要注意的是,在许多情况下,类似这样的函数可以直接使用 SymPy 类来表示。例如,上述函数可以使用 Piecewise 进行符号化表示。Piecewise 表达式可以通过 subs()x 的特定值进行求值。

>>> from sympy import Piecewise, Eq, pprint
>>> f = Piecewise((0, Eq(x, 0)), (x + 1, True))
>>> pprint(f, use_unicode=True)
⎧  0    for x = 0

⎩x + 1  otherwise
>>> f.subs(x, 0)
0
>>> f.subs(x, 1)
2

完全符号化的表示形式,如 Piecewise,其优势在于它们能准确地表示符号值。例如,在上述 Python def 定义的 f 中,f(x) 隐含地假设 x 是非零的。而 Piecewise 版本则能正确处理这种情况,并且除非已知 x 不为零,否则不会评估为 \(x eq 0\) 的情况。

另一个选项,如果你想得到一个不仅评估,而且总是评估为数值的函数,可以使用 lambdify()。这将把一个 SymPy 表达式转换为一个可以使用 NumPy 评估的函数。

>>> from sympy import lambdify
>>> func = lambdify(x, Piecewise((0, Eq(x, 0)), (x + 1, True)))
>>> import numpy as np 
>>> func(np.arange(5)) 
array([0., 2., 3., 4., 5.])

最终,适合工作的正确工具取决于你在做什么以及你想要的确切行为。

创建自定义函数

创建自定义函数的第一步是继承 Function。子类的名称将是函数的名称。然后应根据您希望提供哪些功能,在此子类上定义不同的方法。

作为本文的一个激励示例,让我们创建一个表示 versine 函数 的自定义函数类。Versine 是一个三角函数,历史上与正弦和余弦等更熟悉的三角函数一起使用。如今它很少使用。Versine 可以通过以下恒等式定义

\[\operatorname{versin}(x) = 1 - \cos(x).\]

SymPy 不包含 versine,因为它在现代数学中使用得非常少,而且它很容易用更熟悉的余弦来定义。

让我们从子类化 Function 开始。

>>> class versin(Function):
...     pass

在这一点上,versin 没有定义任何行为。它与我们上面讨论的 未定义函数 非常相似。注意 versin 是一个类,而 versin(x) 是这个类的一个实例。

>>> versin(x)
versin(x)
>>> isinstance(versin(x), versin)
True

备注

下面描述的所有方法都是可选的。如果你想定义给定的行为,可以包含它们,但如果省略它们,SymPy 将默认保持未评估状态。例如,如果你没有定义 微分diff() 将只会返回一个未评估的 Derivative

使用 eval() 定义自动评估

我们可能希望在自定义函数上定义的第一件也是最常见的事情是自动求值,即在哪些情况下它会返回一个实际值,而不是保持原样未求值。

这是通过定义类方法 eval() 来完成的。eval() 应该接受函数的参数并返回一个值或 None。如果返回 None,函数在这种情况下将保持未评估状态。这也用于定义函数的签名(默认情况下,如果没有 eval() 方法,Function 子类将接受任意数量的参数)。

对于我们的函数 versin,我们可能会回想起对于整数 \(n\),有 \(\cos(n\pi) = (-1)^n\),因此 \(\operatorname{versin}(n\pi) = 1 - (-1)^n\)。我们可以让 versin 在传入 pi 的整数倍时自动计算为此值:

>>> from sympy import pi, Integer
>>> class versin(Function):
...    @classmethod
...    def eval(cls, x):
...        # If x is an integer multiple of pi, x/pi will cancel and be an Integer
...        n = x/pi
...        if isinstance(n, Integer):
...            return 1 - (-1)**n
>>> versin(pi)
2
>>> versin(2*pi)
0

这里我们利用了这样一个事实:如果一个Python函数没有显式返回值,它会自动返回None。所以在if isinstance(n, Integer)语句未被触发的情况下,eval()返回None,而versin保持未求值状态。

>>> versin(x*pi)
versin(pi*x)

备注

Function 子类不应重新定义 __new____init__。如果你想实现 eval() 无法实现的行为,可能更合理的是继承 Expr 而不是 Function

eval() 可以接受任意数量的参数,包括带有 *args 的任意数量参数和可选的关键字参数。函数的 .args 将始终是用户传入的参数。例如

>>> class f(Function):
...     @classmethod
...     def eval(cls, x, y=1, *args):
...         return None
>>> f(1).args
(1,)
>>> f(1, 2).args
(1, 2)
>>> f(1, 2, 3).args
(1, 2, 3)

最后,请注意,一旦 evalf() 被定义,浮点输入的自动评估会自动发生,因此您不需要在 eval() 中显式处理它。

eval() 的最佳实践

在定义 eval() 方法时,某些反模式很常见,应避免使用。

  • 不要仅仅返回一个表达式。

    在上面的例子中,我们可能会有写下来的冲动。

    >>> from sympy import cos
    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         return 1 - cos(x)
    

    然而,这将使得 versin(x)总是 返回 1 - cos(x),无论 x 是什么。如果你想要的只是一个快速的简写来表示 1 - cos(x),那很好,但会更简单和更明确地直接 使用如上所述的Python函数。如果我们这样定义 versin,它实际上永远不会被表示为 versin(x),我们下面定义的其他行为也将无关紧要,因为我们将在 versin 类上定义的其他行为仅在返回的对象实际上是 versin 实例时才适用。因此,例如,versin(x).diff(x) 实际上只是 (1 - cos(x)).diff(x),而不是调用 我们下面定义的 fdiff() 方法

    关键点

    eval() 的目的不是定义函数 是什么 ,从数学上来说,而是指定它应该在哪些输入上自动求值。 函数的数学定义是通过以下方法指定的各种数学属性来确定的,例如 数值求值微分 等。

    如果你发现自己正在这样做,你应该考虑你实际上想要实现什么。如果你只是想要一个表达式的简写函数,直接定义一个 Python 函数会更简单。如果你确实想要一个符号函数,考虑一下你希望它在什么时候评估为其他值,以及什么时候保持未评估状态。一个选项是在 eval() 中使你的函数保持未评估状态,并定义一个 doit() 方法来评估它。

  • 避免过多的自动评估。

    建议最小化 eval() 自动评估的内容。通常将更高级的简化放在 其他方法 中,例如 doit()。记住,你为自动评估定义的任何内容都将 始终 评估。[1] 如前一点所述,如果你评估每个值,那么首先就没有必要有一个符号函数。例如,我们可能会想在 eval() 中对 versin 评估一些三角恒等式,但这些恒等式将始终评估,并且无法表示恒等式的一半。

    也应该避免在 eval() 中执行任何计算缓慢的操作。SymPy 通常假设创建表达式是廉价的,如果事实并非如此,可能会导致性能问题。

    最后,建议避免基于假设在 eval() 中执行自动评估。相反,eval() 通常只应评估显式的数值特殊值,并对其他所有内容返回 None。你可能在上面的例子中注意到,我们使用了 isinstance(n, Integer) 而不是使用假设系统检查 n.is_integer。我们本可以这样做,这将使 versin(n*pi)n = Symbol('n', integer=True) 时也能评估。但这是一种我们可能并不总是希望评估发生的情况,而且如果 n 是一个更复杂的表达式,n.is_integer 的计算成本可能会更高。

    我们来看一个例子。利用恒等式 \(\cos(x + y) = \cos(x)\cos(y) - \sin(x)\sin(y)\),我们可以推导出该恒等式。

    \[::\]

    假设我们决定在 eval() 中自动扩展这个:

    >>> from sympy import Add, sin
    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         if isinstance(x, Add):
    ...             a, b = x.as_two_terms()
    ...             return (versin(a)*versin(b) - versin(a) - versin(b)
    ...                     - sin(a)*sin(b) + 1)
    

    此方法递归地将 Add 项分成两部分,并应用上述恒等式。

    >>> x, y, z = symbols('x y z')
    >>> versin(x + y)
    -sin(x)*sin(y) + versin(x)*versin(y) - versin(x) - versin(y) + 1
    

    但现在无法在不扩展的情况下表示 versin(x + y)。这也会影响其他方法。例如,假设我们定义了 微分(见下文)

    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         if isinstance(x, Add):
    ...             a, b = x.as_two_terms()
    ...             return (versin(a)*versin(b) - versin(a) - versin(b)
    ...                     - sin(a)*sin(b) + 1)
    ...
    ...     def fdiff(self, argindex=1):
    ...         return sin(self.args[0])
    

    我们期望 versin(x + y).diff(x) 返回 sin(x + y),实际上,如果我们在 eval() 中没有展开这个恒等式,它就会。但在这个版本中,versin(x + y) 在调用 diff() 之前会自动展开,相反,我们得到的是一个更复杂的表达式:

    >>> versin(x + y).diff(x)
    sin(x)*versin(y) - sin(x) - sin(y)*cos(x)
    

    情况甚至比那更糟。让我们尝试一个带有三个项的 Add

    >>> versin(x + y + z)
    (-sin(y)*sin(z) + versin(y)*versin(z) - versin(y) - versin(z) +
    1)*versin(x) - sin(x)*sin(y + z) + sin(y)*sin(z) - versin(x) -
    versin(y)*versin(z) + versin(y) + versin(z)
    

    我们可以看到,事情很快就失控了。事实上,versin(Add(*symbols('x:100')))(在包含100项的Add上使用versin())需要超过一秒的时间来评估,而这仅仅是创建表达式,甚至还没有对其进行任何操作。

    像这样的身份最好避免在 eval 中使用,而是通过其他方法实现(在这个例子中,使用 expand_trig())。

  • 当限制输入域时:允许 None 输入假设。

    我们的示例函数 \(\operatorname{versin}(x)\) 是一个从 \(\mathbb{C}\)\(\mathbb{C}\) 的函数,因此它可以接受任何输入。但假设我们有一个函数,它只在某些输入下才有意义。作为第二个例子,让我们定义一个函数 divides

    \[::\]

    也就是说,如果 m 能整除 n,则 divides(m, n) 将是 1,否则为 0。显然,divides 只有在 mn 是整数时才有意义。

    我们可能会倾向于这样为 divides 定义 eval() 方法:

    >>> class divides(Function):
    ...     @classmethod
    ...     def eval(cls, m, n):
    ...         # !! Not actually a good eval() method !!
    ...
    ...         # Evaluate for explicit integer m and n. This part is fine.
    ...         if isinstance(m, Integer) and isinstance(n, Integer):
    ...             return int(n % m == 0)
    ...
    ...         # For symbolic arguments, require m and n to be integer.
    ...         # If we write the logic this way, we will run into trouble.
    ...         if not m.is_integer or not n.is_integer:
    ...             raise TypeError("m and n should be integers")
    

    这里的问题是,通过使用 if not m.is_integer,我们要求 m.is_integer 必须是 True。如果它是 None,它将失败(详见布尔值和三值逻辑指南,了解假设为 None 的含义)。这有两个问题。首先,它迫使用户对任何输入变量定义假设。如果用户省略了它们,它将失败:

    >>> n, m = symbols('n m')
    >>> print(n.is_integer)
    None
    >>> divides(m, n)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    相反,他们必须写

    >>> n, m = symbols('n m', integer=True)
    >>> divides(m, n)
    divides(m, n)
    

    这可能看起来是一个可以接受的限制,但有一个更大的问题。有时,SymPy的假设系统无法推导出一个假设,即使它在数学上是正确的。在这种情况下,它会返回 None (在SymPy的假设中,None 意味着“未定义”和“无法计算”)。例如

    >>> # n and m are still defined as integer=True as above
    >>> divides(2, (m**2 + m)/2)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    这里表达式 (m**2 + m)/2 总是一个整数,但SymPy的假设系统无法推导出这一点:

    >>> print(((m**2 + m)/2).is_integer)
    None
    

    SymPy 的假设系统一直在改进,但总会有这样的情况它无法推导,这是由于问题的基本计算复杂性,以及一般问题 通常 不可判定 的事实。

    因此,对于输入变量的否定假设,应始终进行测试,即,如果假设为 False 则失败,但允许假设为 None

    >>> class divides(Function):
    ...     @classmethod
    ...     def eval(cls, m, n):
    ...         # Evaluate for explicit integer m and n. This part is fine.
    ...         if isinstance(m, Integer) and isinstance(n, Integer):
    ...             return int(n % m == 0)
    ...
    ...         # For symbolic arguments, require m and n to be integer.
    ...         # This is the better way to write this logic.
    ...         if m.is_integer is False or n.is_integer is False:
    ...             raise TypeError("m and n should be integers")
    

    这仍然禁止非整数输入,如预期:

    >>> divides(1.5, 1)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    但在假设为 None 的情况下,它不会失败:

    >>> divides(2, (m**2 + m)/2)
    divides(2, m**2/2 + m/2)
    >>> _.subs(m, 2)
    0
    >>> n, m = symbols('n m') # Redefine n and m without the integer assumption
    >>> divides(m, n)
    divides(m, n)
    

    备注

    这条允许 None 假设的规则仅适用于会引发异常的情况下,例如在检查输入域的类型时。在执行简化或其他操作的情况下,应将 None 假设视为“可以是 TrueFalse”,并且不执行可能不具有数学有效性的操作。

假设

接下来你可能想要定义的是我们函数的前提条件。前提系统允许定义在给定输入的情况下你的函数具有哪些数学属性,例如,“当 \(x\)实数 时,\(f(x)\)正数。”

关于假设系统指南深入探讨了假设系统。建议首先阅读该指南,以了解不同假设的含义以及假设系统的工作原理。

最简单的情况是一个函数无论输入如何,总是有给定的假设。在这种情况下,您可以直接在类上定义 is_assumption

例如,我们的 示例 divides 函数 总是一个整数,因为它的值总是 0 或 1:

>>> class divides(Function):
...     is_integer = True
...     is_negative = False
>>> divides(m, n).is_integer
True
>>> divides(m, n).is_nonnegative
True

然而,一般来说,一个函数的假设取决于其输入的假设。在这种情况下,你应该定义一个 evalassumption 方法。

对于我们的\(\operatorname{versin}(x)\) 示例,当 \(x\) 为实数时,该函数的值总是介于 \([0, 2]\) 之间,并且当 \(x\)\(\pi\) 的偶数倍时,函数值恰好为 0。因此,versin(x)x 为实数时应为 非负,并且在 x 为实数且不是 \(\pi\) 的偶数倍时应为 。请记住,默认情况下,函数的定义域是整个 \(\mathbb{C}\),并且 versin(x) 对于非实数 x 也完全有意义。

要检查 x 是否是 pi 的偶数倍,我们可以使用 as_independent() 来结构化地匹配 xcoeff*pi。像这样在假设处理程序中结构化地拆分子表达式比使用类似 (x/pi).is_even 的方法更好,因为后者会创建一个新的表达式 x/pi。创建新表达式的速度要慢得多。此外,每当创建一个表达式时,创建表达式时调用的构造函数通常会查询假设。如果不小心,这可能会导致无限递归。因此,假设处理程序的一个好的一般规则是,永远不要在假设处理程序中创建新表达式。始终使用 as_independent 这样的结构化方法来拆分函数的参数。

注意,对于非实数 \(x\)\(\operatorname{versin}(x)\) 可以是 非负的,例如:

>>> from sympy import I
>>> 1 - cos(pi + I*pi)
1 + cosh(pi)
>>> (1 - cos(pi + I*pi)).evalf()
12.5919532755215

因此,对于 _eval_is_nonnegative 处理程序,如果 x.is_realTrue,我们希望返回 True,但如果 x.is_realFalseNone,则返回 None。读者可以自行练习处理非实数 x 的情况,使得 versin(x) 为非负数,使用类似于 _eval_is_positive 处理程序的逻辑。

在假设处理程序方法中,如同所有方法一样,我们可以使用 self.args 访问函数的参数。

>>> from sympy.core.logic import fuzzy_and, fuzzy_not
>>> class versin(Function):
...     def _eval_is_nonnegative(self):
...         # versin(x) is nonnegative if x is real
...         x = self.args[0]
...         if x.is_real is True:
...             return True
...
...     def _eval_is_positive(self):
...         # versin(x) is positive if x is real and not an even multiple of pi
...         x = self.args[0]
...
...         # x.as_independent(pi, as_Add=False) will split x as a Mul of the
...         # form coeff*pi
...         coeff, pi_ = x.as_independent(pi, as_Add=False)
...         # If pi_ = pi, x = coeff*pi. Otherwise x is not (structurally) of
...         # the form coeff*pi.
...         if pi_ == pi:
...             return fuzzy_and([x.is_real, fuzzy_not(coeff.is_even)])
...         elif x.is_real is False:
...             return False
...         # else: return None. We do not know for sure whether x is an even
...         # multiple of pi
>>> versin(1).is_nonnegative
True
>>> versin(2*pi).is_positive
False
>>> versin(3*pi).is_positive
True

注意在更复杂的 _eval_is_positive() 处理程序中使用了 fuzzy_ 函数,以及对 if/elif 的仔细处理。在使用假设时,始终小心处理 正确处理三值逻辑 非常重要。这确保了当 x.is_realcoeff.is_evenNone 时,该方法返回正确的答案。

警告

切勿将 <code class="docutils literal notranslate"><span class="pre">is_*assumption*</span></code> 定义为 @property 方法。这样做会破坏其他假设的自动推导。<code class="docutils literal notranslate"><span class="pre">is_*assumption*</span></code> 应仅定义为等于 TrueFalse 的类变量。如果假设依赖于函数的 .args,请定义 <code class="docutils literal notranslate"><span class="pre">\_eval\_*assumption*</span></code> 方法。

在这个例子中,不需要定义 _eval_is_real() ,因为它可以从其他假设中自动推导出来,因为 nonnegative -> real 。一般来说,你应该避免定义假设系统可以自动推导的假设,给定其 已知事实

>>> versin(1).is_real
True

假设系统通常能够推导出比你想象的更多的内容。例如,从上述内容中,它可以推导出当 n 是整数时,versin(2*n*pi) 为零。

>>> n = symbols('n', integer=True)
>>> versin(2*n*pi).is_zero
True

在手动编码之前,总是值得检查假设系统是否可以自动推导出某些内容。

最后,提醒一句:在编码假设时要非常小心正确性。确保使用各种假设的精确 定义,并始终检查你是否正确处理了带有模糊三值逻辑函数的 None 情况。不正确或不一致的假设可能导致细微的错误。建议在函数具有非平凡假设处理程序时,使用单元测试检查所有各种情况。SymPy 本身定义的所有函数都需要进行广泛的测试。

使用 evalf() 进行数值评估

这里我们展示如何定义一个函数应如何数值评估为一个浮点 Float 值,例如,通过 evalf()。实现数值评估在 SymPy 中启用了几种行为。例如,一旦定义了 evalf(),你就可以绘制你的函数,并且类似不等式的东西可以评估为显式值。

如果你的函数与 mpmath 中的某个函数同名,这在 SymPy 包含的大多数函数中都是如此,数值计算将会自动进行,你不需要做任何事情。

如果不是这种情况,可以通过定义方法 _eval_evalf(self, prec) 来指定数值评估,其中 prec 是输入的二进制精度。该方法应返回按给定精度评估的表达式,如果无法评估则返回 None

备注

_eval_evalf()prec 参数是 二进制 精度,即浮点表示中的位数。这与 evalf() 方法的第一个参数不同,后者是 十进制 精度,或 dps。例如,Float 的默认二进制精度是 53,对应于十进制精度 15。因此,如果你的 _eval_evalf() 方法递归调用另一个表达式的 evalf,它应该调用 expr._eval_evalf(prec) 而不是 expr.evalf(prec),因为后者会错误地将 prec 用作十进制精度。

我们可以通过递归计算 \(2\sin^2\left(\frac{x}{2}\right)\) 来定义我们的示例 \(\operatorname{versin}(x)\) 函数的数值评估,这是书写 \(1 - \cos(x)\) 的一种更数值稳定的方式。

>>> from sympy import sin
>>> class versin(Function):
...     def _eval_evalf(self, prec):
...         return (2*sin(self.args[0]/2)**2)._eval_evalf(prec)
>>> versin(1).evalf()
0.459697694131860

一旦定义了 _eval_evalf(),这将启用浮点输入的自动评估。不需要在 eval() 中手动实现这一点。

>>> versin(1.)
0.459697694131860

请注意,evalf() 可以传递任何表达式,而不仅仅是那些可以数值计算的表达式。在这种情况下,预期表达式的数值部分将被计算。遵循的一般模式是对函数的参数递归调用 _eval_evalf(prec)

如果可能,最好重用现有 SymPy 函数中定义的 evalf 功能。然而,在某些情况下,直接使用 mpmath 将是必要的。

重写与简化

各种简化函数和方法允许在自定义子类上指定它们的行为。SymPy 中的每个函数并不都有这样的钩子。详情请参阅每个函数的文档。

rewrite()

The rewrite() 方法允许根据特定的函数或规则重写表达式。例如,

>>> sin(x).rewrite(cos)
cos(x - pi/2)

要实现重写,定义一个方法 _eval_rewrite(self, rule, args, **hints),其中

  • rule 是传递给 rewrite() 方法的 规则 。通常 rule 将是被重写的对象的类,尽管对于更复杂的重写,它可以是任何东西。每个定义了 _eval_rewrite() 的对象都定义了它支持的规则。许多 SymPy 函数重写为常见的类,如 expr.rewrite(Add),以执行简化或其他计算。

  • args 是用于重写的函数的参数。应该使用 args 而不是 self.args,因为在 args 中的任何递归表达式都会被重写(假设调用者使用了 rewrite(deep=True),这是默认设置)。

  • **hints 是额外的关键字参数,可用于指定重写的行为。未知的提示应被忽略,因为它们可能会传递给其他 _eval_rewrite() 方法。如果你递归调用重写,你应该传递 **hints

该方法应返回一个重写的表达式,使用 args 作为函数的参数,如果表达式应保持不变,则返回 None

对于我们的 versin 示例,我们可以实现的一个明显的重写是将 versin(x) 重写为 1 - cos(x)

>>> class versin(Function):
...     def _eval_rewrite(self, rule, args, **hints):
...         if rule == cos:
...             return 1 - cos(*args)
>>> versin(x).rewrite(cos)
1 - cos(x)

一旦我们定义了这个,simplify() 现在能够简化一些包含 versin 的表达式:

>>> from sympy import simplify
>>> simplify(versin(x) + cos(x))
1

doit()

The doit() 方法用于评估“未评估”的函数。要定义 doit(),请实现 doit(self, deep=True, **hints)。如果 deep=Truedoit() 应该递归调用 doit() 在参数上。**hints 将是传递给用户的任何其他关键字参数,这些参数应传递给任何递归调用 doit()。您可以使用 hints 来允许用户为 doit() 指定特定行为。

在自定义 Function 子类中 doit() 的典型用法是执行 eval() 中未执行的更高级评估。

例如,对于我们的 divides 示例,有几个实例可以使用一些恒等式来简化。例如,我们定义了 eval() 来在显式整数上进行评估,但我们也可能希望评估像 divides(k, k*n) 这样的示例,其中可分性在符号上是正确的。eval() 的最佳实践之一是避免过多的自动评估。在这种情况下自动评估可能被认为是过多的,因为它会使用假设系统,这可能会很昂贵。此外,我们可能希望能够表示 divides(k, k*n) 而不总是对其进行评估。

解决方案是在 doit() 中实现这些更高级的评估。这样,我们可以通过调用 expr.doit() 来显式地执行它们,但它们不会默认发生。一个执行此简化的 dividesdoit() 示例(以及 上述 eval() 的定义)可能如下所示:

备注

如果 doit() 返回一个 Python int 字面量,将其转换为 Integer,以便返回的对象是 SymPy 类型。

>>> from sympy import Integer
>>> class divides(Function):
...     # Define evaluation on basic inputs, as well as type checking that the
...     # inputs are not nonintegral.
...     @classmethod
...     def eval(cls, m, n):
...         # Evaluate for explicit integer m and n.
...         if isinstance(m, Integer) and isinstance(n, Integer):
...             return int(n % m == 0)
...
...         # For symbolic arguments, require m and n to be integer.
...         if m.is_integer is False or n.is_integer is False:
...             raise TypeError("m and n should be integers")
...
...     # Define doit() as further evaluation on symbolic arguments using
...     # assumptions.
...     def doit(self, deep=False, **hints):
...         m, n = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...            m, n = m.doit(deep=deep, **hints), n.doit(deep=deep, **hints)
...
...         # divides(m, n) is 1 iff n/m is an integer. Note that m and n are
...         # already assumed to be integers because of the logic in eval().
...         isint = (n/m).is_integer
...         if isint is True:
...             return Integer(1)
...         elif isint is False:
...             return Integer(0)
...         else:
...             return divides(m, n)

(注意,这里使用了约定,即对于所有\(k\),有\(k \mid 0\),因此我们不需要检查mn是否为零。如果使用不同的约定,我们需要在执行简化之前检查m.is_zeron.is_zero。)

>>> n, m, k = symbols('n m k', integer=True)
>>> divides(k, k*n)
divides(k, k*n)
>>> divides(k, k*n).doit()
1

另一种常见的 doit() 实现方式是让它总是返回另一个表达式。这实际上将函数视为另一个表达式的“未求值”形式。

例如,让我们定义一个用于融合乘加的函数:\(\operatorname{FMA}(x, y, z) = xy + z\)。将此函数表达为一个独立的函数可能是有用的,例如,用于代码生成,但在某些情况下,将FMA(x, y, z)“求值”为x*y + z也是有用的,以便它可以与其他表达式正确简化。

>>> from sympy import Number
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...     """
...     @classmethod
...     def eval(cls, x, y, z):
...         # Number is the base class of Integer, Rational, and Float
...         if all(isinstance(i, Number) for i in [x, y, z]):
...            return x*y + z
...
...     def doit(self, deep=True, **hints):
...         x, y, z = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...             x = x.doit(deep=deep, **hints)
...             y = y.doit(deep=deep, **hints)
...             z = z.doit(deep=deep, **hints)
...         return x*y + z
>>> x, y, z = symbols('x y z')
>>> FMA(x, y, z)
FMA(x, y, z)
>>> FMA(x, y, z).doit()
x*y + z

大多数自定义函数不会希望以这种方式定义 doit()。然而,这可以在始终求值的函数和从不求值的函数之间提供一个折中方案,生成一个默认不求值但可以根据需要求值的函数(参见上面的讨论)。

expand()

expand() 函数以多种方式“展开”一个表达式。它实际上是围绕几个子展开提示的包装器。每个函数对应于 expand() 函数/方法的一个提示。可以通过定义 _eval_expand_hint(self, **hints) 来在自定义函数中定义特定的展开 提示。有关定义了哪些提示以及每个特定 expand_hint() 函数的文档(例如,expand_trig())的详细信息,请参阅 expand() 的文档。

**hints 关键字参数是可能传递给扩展函数的额外提示,用于指定额外的行为(这些提示与前一段中描述的预定义 hints 是分开的)。未知的提示应被忽略,因为它们可能适用于其他函数的自定义 expand() 方法。一个常见的提示是 force,其中 force=True 会强制进行可能对所有给定输入假设不具有数学有效性的扩展。例如,expand_log(log(x*y), force=True) 产生 log(x) + log(y),即使这个恒等式对所有复数 xy 并不成立(通常 force=False 是默认值)。

注意,expand() 方法会自动使用其自身的 deep 标志递归地扩展表达式,因此 _eval_expand_* 方法不应递归调用扩展函数的参数。

对于我们的 versin 示例,我们可以通过定义一个 _eval_expand_trig 方法来定义基本的 trig 扩展,该方法递归地调用 expand_trig()1 - cos(x) 上:

>>> from sympy import expand_trig
>>> y = symbols('y')
>>> class versin(Function):
...    def _eval_expand_trig(self, **hints):
...        x = self.args[0]
...        return expand_trig(1 - cos(x))
>>> versin(x + y).expand(trig=True)
sin(x)*sin(y) - cos(x)*cos(y) + 1

一个更复杂的实现可能会尝试将 expand_trig(1 - cos(x)) 的结果重写回 versin 函数。这留给读者作为练习。

微分

要通过 diff() 定义微分,请定义一个方法 fdiff(self, argindex)fdiff() 应返回函数对第 argindex 个变量的导数,不考虑链式法则。argindex1 开始索引。

也就是说,f(x1, ..., xi, ..., xn).fdiff(i) 应该返回 \(\frac{d}{d x_i} f(x_1, \ldots, x_i, \ldots, x_n)\),其中 \(x_k\) 相互独立。diff() 会自动使用 fdiff() 的结果应用链式法则。用户代码应使用 diff(),而不是直接调用 fdiff()

备注

Function 子类应使用 fdiff() 定义微分。Expr 的子类如果不是 Function 子类,则需要定义 _eval_derivative() 。不建议在 Function 子类上重新定义 _eval_derivative()

对于我们的\(\operatorname{versin}\) 示例函数,其导数为 \(\sin(x)\)

>>> class versin(Function):
...     def fdiff(self, argindex=1):
...         # argindex indexes the args, starting at 1
...         return sin(self.args[0])
>>> versin(x).diff(x)
sin(x)
>>> versin(x**2).diff(x)
2*x*sin(x**2)
>>> versin(x + y).diff(x)
sin(x + y)

作为一个具有多个参数的函数的示例,考虑上面定义的融合乘加(FMA)示例\(\operatorname{FMA}(x, y, z) = xy + z\))。

我们有

\[\frac{d}{dx} \operatorname{FMA}(x, y, z) = y,\]
\[\frac{d}{dy} \operatorname{FMA}(x, y, z) = x,\]
\[\frac{d}{dz} \operatorname{FMA}(x, y, z) = 1.\]

因此,FMAfdiff() 方法看起来像这样:

>>> from sympy import Number, symbols
>>> x, y, z = symbols('x y z')
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...     """
...     def fdiff(self, argindex):
...         # argindex indexes the args, starting at 1
...         x, y, z = self.args
...         if argindex == 1:
...             return y
...         elif argindex == 2:
...             return x
...         elif argindex == 3:
...             return 1
>>> FMA(x, y, z).diff(x)
y
>>> FMA(x, y, z).diff(y)
x
>>> FMA(x, y, z).diff(z)
1
>>> FMA(x**2, x + 1, y).diff(x)
x**2 + 2*x*(x + 1)

要保留导数未求值,请引发 sympy.core.function.ArgumentIndexError(self, argindex)。如果未定义 fdiff(),这是默认行为。以下是一个函数 \(f(x, y)\) 的示例,该函数在第一个参数中是线性的,并且在第二个参数上具有未求值的导数。

>>> from sympy.core.function import ArgumentIndexError
>>> class f(Function):
...    @classmethod
...    def eval(cls, x, y):
...        pass
...
...    def fdiff(self, argindex):
...        if argindex == 1:
...           return 1
...        raise ArgumentIndexError(self, argindex)
>>> f(x, y).diff(x)
1
>>> f(x, y).diff(y)
Derivative(f(x, y), y)

打印

你可以通过各种 打印机 来定义一个函数如何打印自身,例如 字符串打印机漂亮打印机LaTeX 打印机,以及各种语言的代码打印机,如 CFortran

在大多数情况下,您不需要定义任何打印方法。默认行为是使用函数名称进行打印。然而,在某些情况下,我们可能希望为函数定义特殊的打印方式。

例如,对于我们上面的 divides 示例,我们可能希望 LaTeX 打印机打印一个更数学化的表达式。让我们让 LaTeX 打印机将 divides(m, n) 表示为 \left [ m \middle | n \right ],看起来像 \(\left [ m \middle | n \right ]\)(这里 \([P]\)Iverson 括号,如果 \(P\) 为真则为 \(1\),如果 \(P\) 为假则为 \(0\))。

定义SymPy对象的打印方式主要有两种。一种是在打印类上定义一个打印器。大多数属于SymPy库的类应该使用这种方法,通过在 sympy.printing 中的相应类上定义打印器。对于用户代码,如果你在定义一个自定义打印器,或者如果你有很多自定义函数需要定义打印方式,这种方法可能是首选。参见 自定义打印机的示例 以了解如何以这种方式定义打印器的示例。

另一种方法是将打印定义为函数类的方法。为此,首先查找您要为其定义打印的打印机的 printmethod 属性。这是您应为该打印机定义的方法的名称。对于 LaTeX 打印机,LatexPrinter.printmethod'_latex'。打印方法总是接受一个参数,printer。应使用 printer._print 来递归打印任何其他表达式,包括函数的参数。

因此,为了定义我们的 divides LaTeX 打印机,我们将在类上定义函数 _latex(self, printer),如下所示:

>>> from sympy import latex
>>> class divides(Function):
...     def _latex(self, printer):
...         m, n = self.args
...         _m, _n = printer._print(m), printer._print(n)
...         return r'\left [ %s \middle | %s \right ]' % (_m, _n)
>>> print(latex(divides(m, n)))
\left [ m \middle | n \right ]

关于如何定义打印机方法以及需要避免的一些陷阱的更多细节,请参见 自定义打印方法的示例。最重要的是,你应该始终使用 printer._print() 来递归打印自定义打印机内部函数的参数。

其他方法

可以在自定义函数上定义几种其他方法来指定各种行为。

inverse()

inverse(self, argindex=1) 方法可以定义来指定函数的逆。这被 solve()solveset() 使用。argindex 参数是函数的参数,从1开始(类似于 fdiff() 方法 中相同参数名称)。

inverse() 应该返回一个函数(而不是表达式)作为逆函数。如果逆函数是一个比单个函数更大的表达式,它可以返回一个 lambda 函数。

inverse() 应该仅定义为一一对应的函数。换句话说,f(x).inverse()f(x)左逆。在非一一对应的函数上定义 inverse() 可能会导致 solve() 无法给出包含该函数的表达式的所有可能解。

我们的 示例 versine 函数 不是一对一的(因为余弦函数不是),但其反函数 \(\operatorname{arcversin}\) 是。我们可以如下定义它(使用与 SymPy 中其他反三角函数相同的命名约定):

>>> class aversin(Function):
...     def inverse(self, argindex=1):
...         return versin

这使得 solve() 可以用于 aversin(x)

>>> from sympy import solve
>>> solve(aversin(x) - y, x)
[versin(y)]

as_real_imag()

方法 as_real_imag() 定义了如何将一个函数分解为其实部和虚部。它被各种单独操作表达式实部和虚部的 SymPy 函数所使用。

as_real_imag(self, deep=True, **hints) 应返回一个包含函数实部和虚部的2元组。即 expr.as_real_imag() 返回 (re(expr), im(expr)),其中 expr == re(expr) + im(expr)*I,且 re(expr)im(expr) 为实数。

如果 deep=True,它应该递归地调用其参数上的 as_real_imag(deep=True, **hints)。与 doit() _eval_expand_*() 方法 一样,**hints 可以是任何提示,允许用户指定方法的行为。未知的提示应被忽略,并在任何递归调用中传递,以防它们是为其他 as_real_imag() 方法准备的。

对于我们的versin 示例,我们可以递归地使用已经为 1 - cos(x) 定义的 as_real_imag()

>>> class versin(Function):
...     def as_real_imag(self, deep=True, **hints):
...         return (1 - cos(self.args[0])).as_real_imag(deep=deep, **hints)
>>> versin(x).as_real_imag()
(-cos(re(x))*cosh(im(x)) + 1, sin(re(x))*sinh(im(x)))

定义 as_real_imag() 也会自动使 expand_complex() 生效。

>>> versin(x).expand(complex=True)
I*sin(re(x))*sinh(im(x)) - cos(re(x))*cosh(im(x)) + 1

杂项 _eval_* 方法

在 SymPy 中有很多其他函数,它们的行为可以通过自定义的 _eval_* 方法在自定义函数上定义,类似于上面描述的方法。有关如何定义每个方法的详细信息,请参阅特定函数的文档。

完整示例

以下是本指南中定义的示例函数的完整示例。有关每个方法的详细信息,请参见上面的章节。

正矢

versine(余矢)函数定义为

\[\operatorname{versin}(x) = 1 - \cos(x).\]

Versine 是为所有复数定义的一个简单函数的例子。其数学定义简单,这使得在其上定义所有上述方法变得直接(在大多数情况下,我们可以直接复用 SymPy 在 1 - cos(x) 上定义的现有逻辑)。

定义

>>> from sympy import Function, cos, expand_trig, Integer, pi, sin
>>> from sympy.core.logic import fuzzy_and, fuzzy_not
>>> class versin(Function):
...     r"""
...     The versine function.
...
...     $\operatorname{versin}(x) = 1 - \cos(x) = 2\sin(x/2)^2.$
...
...     Geometrically, given a standard right triangle with angle x in the
...     unit circle, the versine of x is the positive horizontal distance from
...     the right angle of the triangle to the rightmost point on the unit
...     circle. It was historically used as a more numerically accurate way to
...     compute 1 - cos(x), but it is rarely used today.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Versine
...     .. [2] https://blogs.scientificamerican.com/roots-of-unity/10-secret-trig-functions-your-math-teachers-never-taught-you/
...     """
...     # Define evaluation on basic inputs.
...     @classmethod
...     def eval(cls, x):
...         # If x is an explicit integer multiple of pi, x/pi will cancel and
...         # be an Integer.
...         n = x/pi
...         if isinstance(n, Integer):
...             return 1 - (-1)**n
...
...     # Define numerical evaluation with evalf().
...     def _eval_evalf(self, prec):
...         return (2*sin(self.args[0]/2)**2)._eval_evalf(prec)
...
...     # Define basic assumptions.
...     def _eval_is_nonnegative(self):
...         # versin(x) is nonnegative if x is real
...         x = self.args[0]
...         if x.is_real is True:
...             return True
...
...     def _eval_is_positive(self):
...         # versin(x) is positive if x is real and not an even multiple of pi
...         x = self.args[0]
...
...         # x.as_independent(pi, as_Add=False) will split x as a Mul of the
...         # form n*pi
...         coeff, pi_ = x.as_independent(pi, as_Add=False)
...         # If pi_ = pi, x = coeff*pi. Otherwise pi_ = 1 and x is not
...         # (structurally) of the form n*pi.
...         if pi_ == pi:
...             return fuzzy_and([x.is_real, fuzzy_not(coeff.is_even)])
...         elif x.is_real is False:
...             return False
...         # else: return None. We do not know for sure whether x is an even
...         # multiple of pi
...
...     # Define the behavior for various simplification and rewriting
...     # functions.
...     def _eval_rewrite(self, rule, args, **hints):
...         if rule == cos:
...             return 1 - cos(*args)
...         elif rule == sin:
...             return 2*sin(x/2)**2
...
...     def _eval_expand_trig(self, **hints):
...         x = self.args[0]
...         return expand_trig(1 - cos(x))
...
...     def as_real_imag(self, deep=True, **hints):
...         # reuse _eval_rewrite(cos) defined above
...         return self.rewrite(cos).as_real_imag(deep=deep, **hints)
...
...     # Define differentiation.
...     def fdiff(self, argindex=1):
...         return sin(self.args[0])

示例

评估:

>>> x, y = symbols('x y')
>>> versin(x)
versin(x)
>>> versin(2*pi)
0
>>> versin(1.0)
0.459697694131860

假设:

>>> n = symbols('n', integer=True)
>>> versin(n).is_real
True
>>> versin((2*n + 1)*pi).is_positive
True
>>> versin(2*n*pi).is_zero
True
>>> print(versin(n*pi).is_positive)
None
>>> r = symbols('r', real=True)
>>> print(versin(r).is_positive)
None
>>> nr = symbols('nr', real=False)
>>> print(versin(nr).is_nonnegative)
None

简化:

>>> a, b = symbols('a b', real=True)
>>> from sympy import I
>>> versin(x).rewrite(cos)
1 - cos(x)
>>> versin(x).rewrite(sin)
2*sin(x/2)**2
>>> versin(2*x).expand(trig=True)
2 - 2*cos(x)**2
>>> versin(a + b*I).expand(complex=True)
I*sin(a)*sinh(b) - cos(a)*cosh(b) + 1

微分:

>>> versin(x).diff(x)
sin(x)

解决:

aversin 的更通用版本将定义上述所有方法)

>>> class aversin(Function):
...     def inverse(self, argindex=1):
...         return versin
>>> from sympy import solve
>>> solve(aversin(x**2) - y, x)
[-sqrt(versin(y)), sqrt(versin(y))]

划分

divides 是一个由定义的函数

\[::\]

即,divides(m, n)m 能整除 n 时为 1,在 m 不能整除 n 时为 0。它仅对整数 mn 定义。为了简化,我们采用约定,所有整数 m 都能整除 0,即 \(m \mid 0\)

divides 是一个仅针对某些输入值(整数)定义的函数的示例。divides 还展示了如何定义自定义打印机(_latex())。

定义

>>> from sympy import Function, Integer
>>> from sympy.core.logic import fuzzy_not
>>> class divides(Function):
...     r"""
...     $$\operatorname{divides}(m, n) = \begin{cases} 1 & \text{for}\: m \mid n \\ 0 & \text{for}\: m\not\mid n  \end{cases}.$$
...
...     That is, ``divides(m, n)`` is ``1`` if ``m`` divides ``n`` and ``0``
...     if ``m`` does not divide ``n`. It is undefined if ``m`` or ``n`` are
...     not integers. For simplicity, the convention is used that
...     ``divides(m, 0) = 1`` for all integers ``m``.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Divisor#Definition
...     """
...     # Define evaluation on basic inputs, as well as type checking that the
...     # inputs are not nonintegral.
...     @classmethod
...     def eval(cls, m, n):
...         # Evaluate for explicit integer m and n.
...         if isinstance(m, Integer) and isinstance(n, Integer):
...             return int(n % m == 0)
...
...         # For symbolic arguments, require m and n to be integer.
...         if m.is_integer is False or n.is_integer is False:
...             raise TypeError("m and n should be integers")
...
...     # Define basic assumptions.
...
...     # divides is always either 0 or 1.
...     is_integer = True
...     is_negative = False
...
...     # Whether divides(m, n) is 0 or 1 depends on m and n. Note that this
...     # method only makes sense because we don't automatically evaluate on
...     # such cases, but instead simplify these cases in doit() below.
...     def _eval_is_zero(self):
...         m, n = self.args
...         if m.is_integer and n.is_integer:
...              return fuzzy_not((n/m).is_integer)
...
...     # Define doit() as further evaluation on symbolic arguments using
...     # assumptions.
...     def doit(self, deep=False, **hints):
...         m, n = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...            m, n = m.doit(deep=deep, **hints), n.doit(deep=deep, **hints)
...
...         # divides(m, n) is 1 iff n/m is an integer. Note that m and n are
...         # already assumed to be integers because of the logic in eval().
...         isint = (n/m).is_integer
...         if isint is True:
...             return Integer(1)
...         elif isint is False:
...             return Integer(0)
...         else:
...             return divides(m, n)
...
...     # Define LaTeX printing for use with the latex() function and the
...     # Jupyter notebook.
...     def _latex(self, printer):
...         m, n = self.args
...         _m, _n = printer._print(m), printer._print(n)
...         return r'\left [ %s \middle | %s \right ]' % (_m, _n)
...

示例

评估

>>> from sympy import symbols
>>> n, m, k = symbols('n m k', integer=True)
>>> divides(3, 10)
0
>>> divides(3, 12)
1
>>> divides(m, n).is_integer
True
>>> divides(k, 2*k)
divides(k, 2*k)
>>> divides(k, 2*k).is_zero
False
>>> divides(k, 2*k).doit()
1

打印:

>>> str(divides(m, n)) # This is using the default str printer
'divides(m, n)'
>>> print(latex(divides(m, n)))
\left [ m \middle | n \right ]

融合乘加 (FMA)

融合乘加 (FMA) 是先进行乘法然后进行加法的操作:

\[\operatorname{FMA}(x, y, z) = xy + z.\]

它通常在硬件中作为一个单独的浮点运算来实现,这种运算在舍入和性能方面优于等效的乘法和加法运算的组合。

FMA 是一个自定义函数的示例,它被定义为另一个函数的未求值“简写”。这是因为 doit() 方法被定义为返回 x*y + z,这意味着 FMA 函数可以轻松地求值为其所表示的表达式,但 eval() 方法 返回任何内容(除非 xyz 都是显式的数值),这意味着它默认保持未求值状态。

与此对比,versine 示例将 versin 视为一个独立的头等函数。尽管 versin(x) 可以用其他函数(1 - cos(x))表示,但它在 versin.eval() 中不会对一般的符号输入进行求值,并且 versin.doit() 根本未定义。

FMA 也是一个在多变量上定义的连续函数示例,它展示了在 fdiff 示例中 argindex 的工作原理。

最后,FMA 展示了一个为 CC++ 定义一些代码打印机的示例(使用来自 C99CodePrinter.printmethodCXX11CodePrinter.printmethod 的方法名称),因为这是该函数的典型用例。

FMA 的数学定义非常简单,定义其上的每种方法都很容易,但这里只展示了几种。versinedivides 示例展示了如何定义本指南中讨论的其他重要方法。

请注意,如果你想在代码生成中实际使用融合乘加,SymPy 中已经有一个版本 sympy.codegen.cfunctions.fma(),它被现有的代码打印机支持。这里的版本仅设计为示例用途。

定义

>>> from sympy import Number, symbols, Add, Mul
>>> x, y, z = symbols('x y z')
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...
...     FMA is often defined as a single operation in hardware for better
...     rounding and performance.
...
...     FMA can be evaluated by using the doit() method.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add
...     """
...     # Define automatic evaluation on explicit numbers
...     @classmethod
...     def eval(cls, x, y, z):
...         # Number is the base class of Integer, Rational, and Float
...         if all(isinstance(i, Number) for i in [x, y, z]):
...            return x*y + z
...
...     # Define numerical evaluation with evalf().
...     def _eval_evalf(self, prec):
...         return self.doit(deep=False)._eval_evalf(prec)
...
...     # Define full evaluation to Add and Mul in doit(). This effectively
...     # treats FMA(x, y, z) as just a shorthand for x*y + z that is useful
...     # to have as a separate expression in some contexts and which can be
...     # evaluated to its expanded form in other contexts.
...     def doit(self, deep=True, **hints):
...         x, y, z = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...             x = x.doit(deep=deep, **hints)
...             y = y.doit(deep=deep, **hints)
...             z = z.doit(deep=deep, **hints)
...         return x*y + z
...
...     # Define FMA.rewrite(Add) and FMA.rewrite(Mul).
...     def _eval_rewrite(self, rule, args, **hints):
...         x, y, z = self.args
...         if rule in [Add, Mul]:
...             return self.doit()
...
...     # Define differentiation.
...     def fdiff(self, argindex):
...         # argindex indexes the args, starting at 1
...         x, y, z = self.args
...         if argindex == 1:
...             return y
...         elif argindex == 2:
...             return x
...         elif argindex == 3:
...             return 1
...
...     # Define code printers for ccode() and cxxcode()
...     def _ccode(self, printer):
...         x, y, z = self.args
...         _x, _y, _z = printer._print(x), printer._print(y), printer._print(z)
...         return "fma(%s, %s, %s)" % (_x, _y, _z)
...
...     def _cxxcode(self, printer):
...         x, y, z = self.args
...         _x, _y, _z = printer._print(x), printer._print(y), printer._print(z)
...         return "std::fma(%s, %s, %s)" % (_x, _y, _z)

示例

评估:

>>> x, y, z = symbols('x y z')
>>> FMA(2, 3, 4)
10
>>> FMA(x, y, z)
FMA(x, y, z)
>>> FMA(x, y, z).doit()
x*y + z
>>> FMA(x, y, z).rewrite(Add)
x*y + z
>>> FMA(2, pi, 1).evalf()
7.28318530717959

微分

>>> FMA(x, x, y).diff(x)
2*x
>>> FMA(x, y, x).diff(x)
y + 1

代码打印机

>>> from sympy import ccode, cxxcode
>>> ccode(FMA(x, y, z))
'fma(x, y, z)'
>>> cxxcode(FMA(x, y, z))
'std::fma(x, y, z)'

额外提示

  • SymPy 包含了数十个函数。这些函数可以作为如何编写自定义函数的实用示例,特别是当该函数与已实现的函数类似时。请记住,本指南中的所有内容同样适用于 SymPy 自带的函数和用户定义的函数。实际上,本指南旨在同时作为 SymPy 贡献者的开发者指南和 SymPy 终端用户的指南。

  • 如果你有许多自定义函数共享相同的逻辑,你可以使用一个公共基类来包含这些共享逻辑。例如,请参见 SymPy 中三角函数的源代码,其中使用了 TrigonometricFunctionInverseTrigonometricFunctionReciprocalTrigonometricFunction 基类,并包含了一些共享逻辑。

  • 与任何代码一样,为您的函数编写广泛的测试是一个好主意。SymPy 测试套件 是一个很好的资源,可以作为如何为这类函数编写测试的示例。SymPy 本身包含的所有代码都必须经过测试。包含在 SymPy 中的函数还应始终包含一个带有参考文献、数学定义和 doctest 示例的文档字符串。