高级表达式操作

在本节中,我们将讨论一些可以执行表达式高级操作的方法。

理解表达式树

在我们能够这样做之前,我们需要理解表达式在 SymPy 中是如何表示的。一个数学表达式被表示为一棵树。让我们以表达式 \(x^2 + xy\) 为例,即 x**2 + x*y。我们可以通过使用 srepr 来查看这个表达式在内部的样子。

>>> from sympy import *
>>> x, y, z = symbols('x y z')
>>> expr = x**2 + x*y
>>> srepr(expr)
"Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))"

分解这个最简单的方法是查看表达式树的图示:

digraph{ # Graph style "ordering"="out" "rankdir"="TD" ######### # Nodes # ######### "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; "Pow(Symbol('x'), Integer(2))_(0,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Symbol('x')_(0, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Integer(2)_(0, 1)" ["color"="black", "label"="Integer(2)", "shape"="ellipse"]; "Mul(Symbol('x'), Symbol('y'))_(1,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Symbol('x')_(1, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol('y')_(1, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; ######### # Edges # ######### "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" -> "Pow(Symbol('x'), Integer(2))_(0,)"; "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" -> "Mul(Symbol('x'), Symbol('y'))_(1,)"; "Pow(Symbol('x'), Integer(2))_(0,)" -> "Symbol('x')_(0, 0)"; "Pow(Symbol('x'), Integer(2))_(0,)" -> "Integer(2)_(0, 1)"; "Mul(Symbol('x'), Symbol('y'))_(1,)" -> "Symbol('x')_(1, 0)"; "Mul(Symbol('x'), Symbol('y'))_(1,)" -> "Symbol('y')_(1, 1)"; }

备注

上面的图表是使用 Graphvizdotprint 函数制作的。

首先,让我们看看这棵树的叶子。符号是类 Symbol 的实例。虽然我们一直在做

>>> x = symbols('x')

我们也可以这样做

>>> x = Symbol('x')

无论哪种方式,我们都会得到一个名为“x”的符号 [1]。对于表达式中的数字2,我们得到了 Integer(2)Integer 是 SymPy 中用于整数的类。它类似于 Python 内置类型 int,只不过 Integer 与其他 SymPy 类型配合得很好。

当我们写 x**2 时,这会创建一个 Pow 对象。 Pow 是“power”的缩写。

>>> srepr(x**2)
"Pow(Symbol('x'), Integer(2))"

我们可以通过调用 Pow(x, 2) 来创建相同的对象。

>>> Pow(x, 2)
x**2

请注意,在 srepr 输出中,我们看到 Integer(2),这是 SymPy 版本的整数,尽管从技术上讲,我们输入的是 2,一个 Python 整数。通常,每当你通过某些函数或操作将 SymPy 对象与非 SymPy 对象结合时,非 SymPy 对象将被转换为 SymPy 对象。执行此操作的函数是 sympify [2]

>>> type(2)
<... 'int'>
>>> type(sympify(2))
<class 'sympy.core.numbers.Integer'>

我们已经看到 x**2 被表示为 Pow(x, 2)。那么 x*y 呢?正如我们所预期的,这是 xy 的乘法。SymPy 中乘法的类是 Mul

>>> srepr(x*y)
"Mul(Symbol('x'), Symbol('y'))"

因此,我们可以通过编写 Mul(x, y) 来创建相同的对象。

>>> Mul(x, y)
x*y

现在我们得到了最终的表达式,x**2 + x*y。这是我们最后两个对象 Pow(x, 2)Mul(x, y) 的加法。SymPy 中加法的类是 Add,所以,正如你可能预期的那样,要创建这个对象,我们使用 Add(Pow(x, 2), Mul(x, y))

>>> Add(Pow(x, 2), Mul(x, y))
x**2 + x*y

SymPy 表达式树可以有很多分支,并且可以非常深或非常广。这里是一个更复杂的例子

>>> expr = sin(x*y)/2 - x**2 + 1/y
>>> srepr(expr)
"Add(Mul(Integer(-1), Pow(Symbol('x'), Integer(2))), Mul(Rational(1, 2),
sin(Mul(Symbol('x'), Symbol('y')))), Pow(Symbol('y'), Integer(-1)))"

这里是一个图表

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Half()_(0, 0)" ["color"="black", "label"="Rational(1, 2)", "shape"="ellipse"]; "Symbol(y)_(2, 0)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "Symbol(x)_(1, 1, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Integer(2)_(1, 1, 1)" ["color"="black", "label"="Integer(2)", "shape"="ellipse"]; "NegativeOne()_(2, 1)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "NegativeOne()_(1, 0)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Symbol(y)_(0, 1, 0, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "Symbol(x)_(0, 1, 0, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Pow(Symbol(x), Integer(2))_(1, 1)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Pow(Symbol(y), NegativeOne())_(2,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)" ["color"="black", "label"="sin", "shape"="ellipse"]; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; ######### # Edges # ######### "Pow(Symbol(y), NegativeOne())_(2,)" -> "Symbol(y)_(2, 0)"; "Pow(Symbol(x), Integer(2))_(1, 1)" -> "Symbol(x)_(1, 1, 0)"; "Pow(Symbol(x), Integer(2))_(1, 1)" -> "Integer(2)_(1, 1, 1)"; "Pow(Symbol(y), NegativeOne())_(2,)" -> "NegativeOne()_(2, 1)"; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" -> "Symbol(x)_(0, 1, 0, 0)"; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" -> "Symbol(y)_(0, 1, 0, 1)"; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" -> "Half()_(0, 0)"; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" -> "NegativeOne()_(1, 0)"; "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)" -> "Mul(Symbol(x), Symbol(y))_(0, 1, 0)"; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" -> "Pow(Symbol(x), Integer(2))_(1, 1)"; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" -> "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Pow(Symbol(y), NegativeOne())_(2,)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)"; }

这个表达式揭示了SymPy表达式树的一些有趣的东西。让我们逐一探讨。

首先来看术语 x**2。正如我们所预期的,我们看到 Pow(x, 2)。再上一层,我们看到我们有 Mul(-1, Pow(x, 2))。在SymPy中没有减法类。x - y 表示为 x + -y,或者更完整地,x + -1*y,即 Add(x, Mul(-1, y))

>>> srepr(x - y)
"Add(Symbol('x'), Mul(Integer(-1), Symbol('y')))"

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Symbol(x)_(1,)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol(y)_(0, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "NegativeOne()_(0, 0)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Mul(NegativeOne(), Symbol(y))_(0,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; ######### # Edges # ######### "Mul(NegativeOne(), Symbol(y))_(0,)" -> "Symbol(y)_(0, 1)"; "Mul(NegativeOne(), Symbol(y))_(0,)" -> "NegativeOne()_(0, 0)"; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" -> "Symbol(x)_(1,)"; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" -> "Mul(NegativeOne(), Symbol(y))_(0,)"; }

接下来,看看 1/y。我们可能会期待看到类似 Div(1, y) 的东西,但与减法类似,SymPy 中没有用于除法的类。相反,除法由 -1 次幂表示。因此,我们有 Pow(y, -1)。如果我们用 y 除以 1 以外的其他东西,比如 x/y,会怎么样呢?让我们看看。

>>> expr = x/y
>>> srepr(expr)
"Mul(Symbol('x'), Pow(Symbol('y'), Integer(-1)))"

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Symbol(x)_(0,)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol(y)_(1, 0)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "NegativeOne()_(1, 1)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Pow(Symbol(y), NegativeOne())_(1,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" ["color"="black", "label"="Mul", "shape"="ellipse"]; ######### # Edges # ######### "Pow(Symbol(y), NegativeOne())_(1,)" -> "Symbol(y)_(1, 0)"; "Pow(Symbol(y), NegativeOne())_(1,)" -> "NegativeOne()_(1, 1)"; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" -> "Symbol(x)_(0,)"; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" -> "Pow(Symbol(y), NegativeOne())_(1,)"; }

我们看到 x/y 表示为 x*y**-1,即 Mul(x, Pow(y, -1))

最后,让我们看一下 sin(x*y)/2 这一项。按照前一个例子的模式,我们可能会期望看到 Mul(sin(x*y), Pow(Integer(2), -1))。但实际上,我们得到的是 Mul(Rational(1, 2), sin(x*y))。有理数在乘法中总是被合并成一个单独的项,因此当我们除以2时,它被表示为乘以1/2。

最后,还有一个注意事项。你可能已经注意到,我们输入表达式的顺序与从 srepr 或图中输出的顺序是不同的。在教程的早期部分,你可能也注意到了这种现象。例如

>>> 1 + x
x + 1

这是因为,在 SymPy 中,交换操作 AddMul 的参数以任意(但一致!)顺序存储,这与输入顺序无关(如果你担心非交换乘法,不用担心。在 SymPy 中,你可以使用 Symbol('A', commutative=False) 创建非交换符号,并且非交换符号的乘法顺序与输入顺序保持一致)。此外,正如我们将在下一节中看到的,打印顺序和内部存储顺序也不必相同。

通常,在使用 SymPy 表达式树时需要记住的一件重要事情是:表达式的内部表示和它的打印方式不一定相同。输入形式也是如此。如果某些表达式操作算法没有按照您预期的方式工作,很可能是因为对象的内部表示与您认为的不同。

递归遍历表达式树

既然你知道了 SymPy 中表达式树的工作原理,让我们来看看如何深入探索一个表达式树。SymPy 中的每个对象都有两个非常重要的属性,funcargs

函数

func 是对象的头部。例如,(x*y).funcMul。通常它与对象的类相同(尽管这条规则有例外)。

关于 func 的两点说明。首先,对象的类不必与创建它时使用的类相同。例如

>>> expr = Add(x, x)
>>> expr.func
<class 'sympy.core.mul.Mul'>

我们创建了 Add(x, x),因此我们可能期望 expr.funcAdd,但我们得到的却是 Mul。为什么会这样?让我们更仔细地看看 expr

>>> expr
2*x

Add(x, x),即 x + x,被自动转换为 Mul(2, x),即 2*x,这是一个 Mul。SymPy 类大量使用 __new__ 类构造函数,它与 __init__ 不同,允许从构造函数返回不同的类。

其次,一些类是特殊处理的,通常是为了效率原因 [3]

>>> Integer(2).func
<class 'sympy.core.numbers.Integer'>
>>> Integer(0).func
<class 'sympy.core.numbers.Zero'>
>>> Integer(-1).func
<class 'sympy.core.numbers.NegativeOne'>

在大多数情况下,这些问题不会困扰我们。特殊类 ZeroOneNegativeOne 等是 Integer 的子类,因此只要使用 isinstance,就不会有问题。

参数

args 是对象的顶级参数。 (x*y).args 将是 (x, y)。 让我们看一些例子

>>> expr = 3*y**2*x
>>> expr.func
<class 'sympy.core.mul.Mul'>
>>> expr.args
(3, x, y**2)

由此,我们可以看到 expr == Mul(3, y**2, x)。 事实上,我们可以看到我们可以完全从其 funcargs 重建 expr

>>> expr.func(*expr.args)
3*x*y**2
>>> expr == expr.func(*expr.args)
True

注意,尽管我们输入了 3*y**2*x,但 args(3, x, y**2)。在 Mul 中,有理系数将首先出现在 args 中,但除此之外,其他所有内容的顺序都没有特定的模式。不过,可以确定的是,确实存在一个顺序。

>>> expr = y**2*3*x
>>> expr.args
(3, x, y**2)

Mul 的 args 是排序的,因此相同的 Mul 将具有相同的 args。但是排序是基于一些旨在使排序唯一且高效的准则进行的,这些准则没有数学意义。

我们的 exprsrepr 形式是 Mul(3, x, Pow(y, 2))。如果我们想要获取 Pow(y, 2)args。注意 y**2expr.args 的第三个位置,即 expr.args[2]

>>> expr.args[2]
y**2

因此,要获取这个的 args ,我们调用 expr.args[2].args

>>> expr.args[2].args
(y, 2)

现在如果我们尝试更深入。y 的参数是什么。或者 2。让我们看看。

>>> y.args
()
>>> Integer(2).args
()

它们都有空的 args。在 SymPy 中,空的 args 表示我们已经到达了表达式树的叶子节点。

因此,SymPy 表达式有两种可能性。要么它的 args 为空,在这种情况下,它是任何表达式树中的叶子节点,要么它有 args,在这种情况下,它是任何表达式树中的分支节点。当它有 args 时,它可以完全从其 funcargs 重建。这一点在关键不变式中得到了体现。

(回忆一下,在Python中如果 a 是一个元组,那么 f(*a) 意味着用 a 的元素作为参数来调用 f ,例如, f(*(1, 2, 3)) 等同于 f(1, 2, 3) 。)

这个关键的不变性使我们能够编写简单的算法来遍历表达式树,改变它们,并将它们重新构建为新的表达式。

遍历树

有了这些知识,让我们看看如何递归遍历一个表达式树。args 的嵌套结构非常适合递归函数。基本情况将是空的 args。让我们编写一个简单的函数,遍历一个表达式并在每一层打印所有的 args

>>> def pre(expr):
...     print(expr)
...     for arg in expr.args:
...         pre(arg)

看看 () 在表达式树中表示叶子是多么好。我们甚至不需要为递归写一个基本情况;它由for循环自动处理。

让我们测试我们的功能。

>>> expr = x*y + 1
>>> pre(expr)
x*y + 1
1
x*y
x
y

你能猜到我们为什么把我们的函数命名为 pre 吗?我们刚刚为我们的表达式树写了一个前序遍历函数。看看你是否能写一个后序遍历函数。

这种遍历在 SymPy 中非常常见,以至于提供了生成器函数 preorder_traversalpostorder_traversal 来简化这种遍历。我们也可以将我们的算法写成

>>> for arg in preorder_traversal(expr):
...     print(arg)
x*y + 1
1
x*y
x
y

防止表达式求值

通常有两种方法来防止求值,一种是在构造表达式时传递 evaluate=False 参数,另一种是通过用 UnevaluatedExpr 包装表达式来创建一个求值停止器。

例如:

>>> from sympy import Add
>>> from sympy.abc import x, y, z
>>> x + x
2*x
>>> Add(x, x)
2*x
>>> Add(x, x, evaluate=False)
x + x

如果你不记得对应于你想构建的表达式的类(运算符重载通常假设 evaluate=True),只需使用 sympify 并传递一个字符串:

>>> from sympy import sympify
>>> sympify("x + x", evaluate=False)
x + x

注意,evaluate=False 不会阻止表达式在后续使用中的未来评估:

>>> expr = Add(x, x, evaluate=False)
>>> expr
x + x
>>> expr + x
3*x

这就是为什么 UnevaluatedExpr 类非常有用。UnevaluatedExpr 是 SymPy 提供的一个方法,它允许用户保持表达式未求值。所谓 未求值 是指其中的值不会与外部的表达式交互以产生简化的输出。例如:

>>> from sympy import UnevaluatedExpr
>>> expr = x + UnevaluatedExpr(x)
>>> expr
x + x
>>> x + expr
2*x + x

单独的 \(x\) 是被 UnevaluatedExpr 包裹的 \(x\)。要释放它:

>>> (x + expr).doit()
3*x

其他示例:

>>> from sympy import *
>>> from sympy.abc import x, y, z
>>> uexpr = UnevaluatedExpr(S.One*5/7)*UnevaluatedExpr(S.One*3/4)
>>> uexpr
(5/7)*(3/4)
>>> x*UnevaluatedExpr(1/x)
x*1/x

需要注意的是,UnevaluatedExpr 不能阻止作为参数给出的表达式的求值。例如:

>>> expr1 = UnevaluatedExpr(x + x)
>>> expr1
2*x
>>> expr2 = sympify('x + x', evaluate=False)
>>> expr2
x + x

记住,如果 expr2 被包含在另一个表达式中,它将被评估。结合这两种方法以防止内部和外部的评估:

>>> UnevaluatedExpr(sympify("x + x", evaluate=False)) + y
y + (x + x)

UnevaluatedExpr 被 SymPy 打印机支持,可以用于以不同的输出形式打印结果。例如

>>> from sympy import latex
>>> uexpr = UnevaluatedExpr(S.One*5/7)*UnevaluatedExpr(S.One*3/4)
>>> print(latex(uexpr))
\frac{5}{7} \cdot \frac{3}{4}

为了释放表达式并获取评估后的 LaTeX 形式,只需使用 .doit()

>>> print(latex(uexpr.doit()))
\frac{15}{28}

脚注