调试 PySpark

PySpark 使用 Spark 作为引擎。PySpark 使用 Py4J 来利用 Spark 提交并计算作业。

在驱动端,PySpark通过使用 Py4J 与JVM上的驱动程序进行通信。当 pyspark.sql.SparkSession pyspark.SparkContext 被创建和初始化时,PySpark 启动 JVM进行通信。

在执行者端,Python 工作者执行和处理 Python 原生函数或数据。只有在 PySpark 应用程序需要 Python 工作者与 JVM 之间的交互时,它们才会被启动。它们仅在需要处理 Python 原生函数或数据时被懒惰地启动,例如,当您执行 pandas UDFs 或 PySpark RDD APIs 时。

该页面侧重于调试PySpark的Python端,包括驱动程序和执行程序,而不是集中于调试JVM。JVM的性能分析和调试在 有用的开发工具 中进行了描述。

请注意,

  • 如果您在本地运行,可以直接通过使用 IDE 调试驱动程序,而无需远程调试功能。使用 IDE 设置 PySpark 的文档 在这里

  • 还有许多其他调试 PySpark 应用程序的方法 。例如,您可以使用开源 远程调试器 进行远程调试,而不是使用这里文档中的 PyCharm Professional。

远程调试 (PyCharm 专业版)

本节描述了在单台机器上驱动程序和执行程序两侧的远程调试,以便于演示。 在执行程序侧调试PySpark的方法与在驱动程序中调试的方法不同。因此,它们将分别演示。 有关在其他机器上调试PySpark应用程序的说明,请参阅针对PyCharm的完整说明,文档记录在 这里

首先,从 运行 菜单中选择 编辑配置… 。这将打开 运行/调试配置对话框 。 您需要在工具栏上单击 + 配置,并从可用配置列表中选择 Python 调试服务器 。 输入此新配置的名称,例如 MyRemoteDebugger ,并指定端口号,例如 12345

PyCharm remote debugger setting
After that, you should install the corresponding version of the pydevd-pycharm package in all the machines which will connect to your PyCharm debugger. In the previous dialog, it shows the command to install.
pip install pydevd-pycharm~=<version of PyCharm on the local machine>

驾驶员侧

要在驱动程序端进行调试,您的应用程序应该能够连接到调试服务器。将带有 pydevd_pycharm.settrace 的代码复制并粘贴到您的 PySpark 脚本的顶部。假设脚本名称是 app.py :

echo "#======================Copy and paste from the previous dialog===========================
import pydevd_pycharm
pydevd_pycharm.settrace('localhost', port=12345, stdoutToServer=True, stderrToServer=True)
#========================================================================================
# Your PySpark application codes:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark.range(10).show()" > app.py

开始使用您的 MyRemoteDebugger 进行调试。

PyCharm run remote debugger
After that, submit your application. This will connect to your PyCharm debugging server and enable you to debug on the driver side remotely.
spark-submit app.py

执行端

要在执行器端进行调试,请在您当前的工作目录中准备如下Python文件。

echo "from pyspark import daemon, worker
def remote_debug_wrapped(*args, **kwargs):
    #======================Copy and paste from the previous dialog===========================
    import pydevd_pycharm
    pydevd_pycharm.settrace('localhost', port=12345, stdoutToServer=True, stderrToServer=True)
    #========================================================================================
    worker.main(*args, **kwargs)
daemon.worker_main = remote_debug_wrapped
if __name__ == '__main__':
    daemon.manager()" > remote_debug.py

您将使用此文件作为您的PySpark应用程序中的Python工作线程,通过使用 spark.python.daemon.module 配置来实现。
使用以下配置运行 pyspark shell:

pyspark --conf spark.python.daemon.module=remote_debug

现在你准备好进行远程调试了。开始使用你的 MyRemoteDebugger 进行调试。

PyCharm run remote debugger
After that, run a job that creates Python workers, for example, as below:
spark.range(10).repartition(1).rdd.map(lambda x: x).collect()

检查资源使用情况 ( top ps )

可以通过典型的方法检查驱动程序和执行器上的Python进程,例如 top ps 命令。

驱动程序侧

在驱动程序端,您可以通过以下方式轻松地从您的PySpark Shell获取进程ID,以了解进程ID和资源。

>>> import os; os.getpid()
18482
ps -fe 18482
UID   PID  PPID   C STIME  TTY           TIME CMD
000 18482 12345   0 0:00PM ttys001    0:00.00 /.../python

执行程序端

要检查执行器端,您可以简单地 grep 它们以找出进程 ID 和相关资源,因为 Python 工作进程是从 pyspark.daemon 生成的。

ps -fe | grep pyspark.daemon
000 12345     1   0  0:00PM ttys000    0:00.00 /.../python -m pyspark.daemon
000 12345     1   0  0:00PM ttys000    0:00.00 /.../python -m pyspark.daemon
000 12345     1   0  0:00PM ttys000    0:00.00 /.../python -m pyspark.daemon
000 12345     1   0  0:00PM ttys000    0:00.00 /.../python -m pyspark.daemon
...

内存使用分析(内存分析器)

memory_profiler 是其中一个分析器,可以让您逐行检查内存使用情况。

驱动侧

除非您在另一台机器上运行您的驱动程序(例如,YARN 集群模式),否则这个有用的工具可以用来轻松调试驱动端的内存使用情况。假设您的 PySpark 脚本名称是 profile_memory.py 。您可以如下进行分析。

echo "from pyspark.sql import SparkSession
#===Your function should be decorated with @profile===
from memory_profiler import profile
@profile
#=====================================================
def my_func():
    session = SparkSession.builder.getOrCreate()
    df = session.range(10000)
    return df.collect()
if __name__ == '__main__':
    my_func()" > profile_memory.py
python -m memory_profiler profile_memory.py
Filename: profile_memory.py

Line #    Mem usage    Increment   Line Contents
================================================
...
     6                             def my_func():
     7     51.5 MiB      0.6 MiB       session = SparkSession.builder.getOrCreate()
     8     51.5 MiB      0.0 MiB       df = session.range(10000)
     9     54.4 MiB      2.8 MiB       return df.collect()

Python/Pandas 用户定义函数

PySpark 提供了远程 memory_profiler 用于 Python/Pandas UDFs,可以通过将 spark.python.profile.memory 配置设置为 true 来启用。可以在带行号的编辑器上使用,例如 Jupyter 笔记本。下面是 Jupyter 笔记本的一个示例。

pyspark --conf spark.python.profile.memory=true
from pyspark.sql.functions import pandas_udf
df = spark.range(10)

@pandas_udf("long")
def add1(x):
  return x + 1

added = df.select(add1("id"))
added.show()
sc.show_profiles()

结果配置文件如下所示。

============================================================
Profile of UDF<id=2>
============================================================
Filename: ...

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     4    974.0 MiB    974.0 MiB          10   @pandas_udf("long")
     5                                         def add1(x):
     6    974.4 MiB      0.4 MiB          10     return x + 1

UDF ID可以在查询计划中看到,例如, add1(...)#2L ArrowEvalPython 中,如下所示。

added.explain()
== Physical Plan ==
*(2) Project [pythonUDF0#11L AS add1(id)#3L]
+- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200
   +- *(1) Range (0, 10, step=1, splits=16)

此功能不支持注册的 UDF 或作为输入/输出的具有迭代器的 UDF。

识别热点循环 (Python 性能分析工具)

Python Profilers 是 Python 自身的有用内置功能。这些功能提供了对 Python 程序的确定性分析,并带有许多有用的统计数据。本节描述了如何在驱动程序和执行器两侧使用它,以便识别高开销或热点代码路径。

驱动程序端

要在驱动程序端使用它,您可以像对待常规 Python 程序一样使用它,因为驱动程序端的 PySpark 是一个常规的 Python 进程,除非您在另一台机器(例如,YARN 集群模式)上运行驱动程序。

echo "from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark.range(10).show()" > app.py
python -m cProfile app.py
...
     129215 function calls (125446 primitive calls) in 5.926 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 1198/405    0.001    0.000    0.083    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
      561    0.001    0.000    0.001    0.000 <frozen importlib._bootstrap>:103(release)
      276    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:143(__init__)
      276    0.000    0.000    0.002    0.000 <frozen importlib._bootstrap>:147(__enter__)
...

执行器端

要在执行器端使用此功能,PySpark 提供了远程 Python Profilers ,可以通过将 spark.python.profile 配置设置为 true 来启用。

pyspark --conf spark.python.profile=true
>>> rdd = sc.parallelize(range(100)).map(str)
>>> rdd.count()
100
>>> sc.show_profiles()
============================================================
Profile of RDD<id=1>
============================================================
         728 function calls (692 primitive calls) in 0.004 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       12    0.001    0.000    0.001    0.000 serializers.py:210(load_stream)
       12    0.000    0.000    0.000    0.000 {built-in method _pickle.dumps}
       12    0.000    0.000    0.001    0.000 serializers.py:252(dump_stream)
       12    0.000    0.000    0.001    0.000 context.py:506(f)
...

Python/Pandas 用户定义函数

要在Python/Pandas UDFs中使用此功能,PySpark为Python/Pandas UDFs提供了远程 Python Profilers ,可以通过将 spark.python.profile 配置设置为 true 来启用。

pyspark --conf spark.python.profile=true
>>> from pyspark.sql.functions import pandas_udf
>>> df = spark.range(10)
>>> @pandas_udf("long")
... def add1(x):
...     return x + 1
...
>>> added = df.select(add1("id"))

>>> added.show()
+--------+
|add1(id)|
+--------+
...
+--------+

>>> sc.show_profiles()
============================================================
Profile of UDF<id=2>
============================================================
         2300 function calls (2270 primitive calls) in 0.006 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       10    0.001    0.000    0.005    0.001 series.py:5515(_arith_method)
       10    0.001    0.000    0.001    0.000 _ufunc_config.py:425(__init__)
       10    0.000    0.000    0.000    0.000 {built-in method _operator.add}
       10    0.000    0.000    0.002    0.000 series.py:315(__init__)
...

可以在查询计划中看到UDF ID,例如, add1(...)#2L 在下面的 ArrowEvalPython 中。

>>> added.explain()
== Physical Plan ==
*(2) Project [pythonUDF0#11L AS add1(id)#3L]
+- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200
   +- *(1) Range (0, 10, step=1, splits=16)

此功能不支持已注册的 UDF。

常见异常 / 错误

PySpark SQL

分析异常

AnalysisException 在分析 SQL 查询计划失败时引发。

示例:

>>> df = spark.range(1)
>>> df['bad_key']
Traceback (most recent call last):
...
pyspark.errors.exceptions.AnalysisException: Cannot resolve column name "bad_key" among (id)

解决方案:

>>> df['id']
Column<'id'>

解析异常

ParseException 在解析 SQL 命令失败时被抛出。

示例:

>>> spark.sql("select * 1")
Traceback (most recent call last):
...
pyspark.errors.exceptions.ParseException:
[PARSE_SYNTAX_ERROR] Syntax error at or near '1': extra input '1'.(line 1, pos 9)

== SQL ==
select * 1
---------^^^

解决方案:

>>> spark.sql("select *")
DataFrame[]

非法参数异常

IllegalArgumentException 在传递非法或不适当的参数时被引发。

示例:

>>> spark.range(1).sample(-1.0)
Traceback (most recent call last):
...
pyspark.errors.exceptions.IllegalArgumentException: requirement failed: Sampling fraction (-1.0) must be on interval [0, 1] without replacement

解决方案:

>>> spark.range(1).sample(1.0)
DataFrame[id: bigint]

Python异常

PythonException 是从Python工作进程中抛出的。

您可以看到从Python工作中抛出的异常类型及其堆栈跟踪,如下所示 TypeError

示例:

>>> import pyspark.sql.functions as sf
>>> from pyspark.sql.functions import udf
>>> def f(x):
...   return sf.abs(x)
...
>>> spark.range(-1, 1).withColumn("abs", udf(f)("id")).collect()
22/04/12 14:52:31 ERROR Executor: Exception in task 7.0 in stage 37.0 (TID 232)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
...
TypeError: Invalid argument, not a string or column: -1 of type <class 'int'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

解决方案:

>>> def f(x):
...   return abs(x)
...
>>> spark.range(-1, 1).withColumn("abs", udf(f)("id")).collect()
[Row(id=-1, abs='1'), Row(id=0, abs='0')]

流查询异常

StreamingQueryException 在StreamingQuery失败时引发。它通常由Python工作进程抛出,并将其包装为 PythonException

示例:

>>> sdf = spark.readStream.format("text").load("python/test_support/sql/streaming")
>>> from pyspark.sql.functions import col, udf
>>> bad_udf = udf(lambda x: 1 / 0)
>>> (sdf.select(bad_udf(col("value"))).writeStream.format("memory").queryName("q1").start()).processAllAvailable()
Traceback (most recent call last):
...
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "<stdin>", line 1, in <lambda>
ZeroDivisionError: division by zero
...
pyspark.errors.exceptions.StreamingQueryException: [STREAM_FAILED] Query [id = 74eb53a8-89bd-49b0-9313-14d29eed03aa, runId = 9f2d5cf6-a373-478d-b718-2c2b6d8a0f24] terminated with exception: Job aborted

解决方案:

修复StreamingQuery并重新执行工作流程。

Spark升级异常

SparkUpgradeException 是因为 Spark 升级而抛出的异常。

示例:

>>> from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
>>> df = spark.createDataFrame([("2014-31-12",)], ["date_str"])
>>> df2 = df.select("date_str", to_date(from_unixtime(unix_timestamp("date_str", "yyyy-dd-aa"))))
>>> df2.collect()
Traceback (most recent call last):
...
pyspark.sql.utils.SparkUpgradeException: You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyy-dd-aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html

解决方案:

>>> spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")
>>> df2 = df.select("date_str", to_date(from_unixtime(unix_timestamp("date_str", "yyyy-dd-aa"))))
>>> df2.collect()
[Row(date_str='2014-31-12', to_date(from_unixtime(unix_timestamp(date_str, yyyy-dd-aa), yyyy-MM-dd HH:mm:ss))=None)]

基于Spark的pandas API

在Spark的pandas API中有一些特定的常见异常/错误。

值错误:无法合并系列或数据框,因为它来自不同的数据框

涉及多个系列或数据框的操作会引发一个 ValueError 如果 compute.ops_on_diff_frames 被禁用(默认情况下禁用)。由于底层 Spark 框架的连接,这些操作可能会很昂贵。因此,用户应注意费用,并仅在必要时启用该标志。

异常:

>>> ps.Series([1, 2]) + ps.Series([3, 4])
Traceback (most recent call last):
...
ValueError: Cannot combine the series or dataframe because it comes from a different dataframe. In order to allow this operation, enable 'compute.ops_on_diff_frames' option.

解决方案:

>>> with ps.option_context('compute.ops_on_diff_frames', True):
...     ps.Series([1, 2]) + ps.Series([3, 4])
...
0    4
1    6
dtype: int64

运行时错误:pandas_udf 的结果向量不是所需的长度

异常:

>>> def f(x) -> ps.Series[np.int32]:
...   return x[:-1]
...
>>> ps.DataFrame({"x":[1, 2], "y":[3, 4]}).transform(f)
22/04/12 13:46:39 ERROR Executor: Exception in task 2.0 in stage 16.0 (TID 88)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
...
RuntimeError: Result vector from pandas_udf was not the required length: expected 1, got 0

解决方案:

>>> def f(x) -> ps.Series[np.int32]:
...   return x
...
>>> ps.DataFrame({"x":[1, 2], "y":[3, 4]}).transform(f)
   x  y
0  1  3
1  2  4

Py4j

Py4JJavaError

Py4JJavaError 在Java客户端代码中发生异常时被引发。您可以看到在Java端抛出的异常类型及其堆栈跟踪,如下所示: java.lang.NullPointerException

示例:

>>> spark.sparkContext._jvm.java.lang.String(None)
Traceback (most recent call last):
...
py4j.protocol.Py4JJavaError: An error occurred while calling None.java.lang.String.
: java.lang.NullPointerException
..

解决方案:

>>> spark.sparkContext._jvm.java.lang.String("x")
'x'

Py4J错误

Py4JError 当发生其他错误时被引发,比如当 Python 客户端程序尝试访问 Java 端不再存在的对象时。

示例:

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.regression import LinearRegression
>>> df = spark.createDataFrame(
...             [(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))],
...             ["label", "weight", "features"],
...         )
>>> lr = LinearRegression(
...             maxIter=1, regParam=0.0, solver="normal", weightCol="weight", fitIntercept=False
...         )
>>> model = lr.fit(df)
>>> model
LinearRegressionModel: uid=LinearRegression_eb7bc1d4bf25, numFeatures=1
>>> model.__del__()
>>> model
Traceback (most recent call last):
...
py4j.protocol.Py4JError: An error occurred while calling o531.toString. Trace:
py4j.Py4JException: Target Object ID does not exist for this gateway :o531
...

解决方案:

访问一个存在于Java端的对象。

Py4J网络错误

Py4JNetworkError 是在网络传输过程中出现问题时引发的错误(例如,连接丢失)。在这种情况下,我们将调试网络并重建连接。

堆栈跟踪

有Spark配置来控制堆栈跟踪:

  • spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled 的默认值为 true,以简化来自 Python UDF 的 traceback。

  • spark.sql.pyspark.jvmStacktrace.enabled 的默认值为 false,以隐藏 JVM stacktrace 并仅显示 Python 友好的异常。

上面的Spark配置与日志级别设置是独立的。通过 pyspark.SparkContext.setLogLevel() 控制日志级别。