模式:使用生成器减少堆内存使用#

在这个模式中,我们使用 Python 的 生成器 来减少任务执行期间的总体堆内存使用。关键思想是对于返回多个对象的任务,我们可以一次返回一个对象,而不是一次性全部返回。这使得工作线程在返回下一个对象之前可以释放前一个返回值所使用的堆内存。

示例用例#

你有一个返回多个大值的任务。另一种可能是返回单个大值的任务,但你希望通过将该值分解成较小的块来通过 Ray 的对象存储流式传输它。

使用普通的 Python 函数,我们可以这样编写这样的任务。以下是一个返回每个大小为 100MB 的 numpy 数组的示例:

import numpy as np


@ray.remote
def large_values(num_returns):
    return [
        np.random.randint(np.iinfo(np.int8).max, size=(100_000_000, 1), dtype=np.int8)
        for _ in range(num_returns)
    ]

然而,这将要求任务在任务结束时同时将所有 num_returns 数组保存在堆内存中。如果有许多返回值,这可能导致高堆内存使用率,并可能引发内存不足错误。

我们可以通过将 large_values 重写为一个 生成器 来修复上述示例。与其一次性返回所有值作为元组或列表,我们可以一次 yield 一个值。

@ray.remote
def large_values_generator(num_returns):
    for i in range(num_returns):
        yield np.random.randint(
            np.iinfo(np.int8).max, size=(100_000_000, 1), dtype=np.int8
        )
        print(f"yielded return value {i}")

代码示例#

import sys
import ray

# fmt: off
# __large_values_start__
import numpy as np


@ray.remote
def large_values(num_returns):
    return [
        np.random.randint(np.iinfo(np.int8).max, size=(100_000_000, 1), dtype=np.int8)
        for _ in range(num_returns)
    ]
# __large_values_end__
# fmt: on


# fmt: off
# __large_values_generator_start__
@ray.remote
def large_values_generator(num_returns):
    for i in range(num_returns):
        yield np.random.randint(
            np.iinfo(np.int8).max, size=(100_000_000, 1), dtype=np.int8
        )
        print(f"yielded return value {i}")
# __large_values_generator_end__
# fmt: on


# A large enough value (e.g. 100).
num_returns = int(sys.argv[1])
# Worker will likely OOM using normal returns.
print("Using normal functions...")
try:
    ray.get(
        large_values.options(num_returns=num_returns, max_retries=0).remote(
            num_returns
        )[0]
    )
except ray.exceptions.WorkerCrashedError:
    print("Worker failed with normal function")

# Using a generator will allow the worker to finish.
# Note that this will block until the full task is complete, i.e. the
# last yield finishes.
print("Using generators...")
ray.get(
    large_values_generator.options(num_returns=num_returns, max_retries=0).remote(
        num_returns
    )[0]
)
print("Success!")
$ RAY_IGNORE_UNHANDLED_ERRORS=1 python test.py 100

Using normal functions...
... -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker...
Worker failed
Using generators...
(large_values_generator pid=373609) yielded return value 0
(large_values_generator pid=373609) yielded return value 1
(large_values_generator pid=373609) yielded return value 2
...
Success!