torch.distributed.elastic.agent.server.local_elastic_agent 的源代码
#!/usr/bin/env python3
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import json
import os
import signal
import socket
from string import Template
import uuid
from typing import Any, Dict, Optional, Tuple
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
RunResult,
SimpleElasticAgent,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed.elastic.events.api import EventMetadataValue
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
log = get_logger(__name__)
__all__ = [
"LocalElasticAgent",
"TORCHELASTIC_ENABLE_FILE_TIMER",
"TORCHELASTIC_TIMER_FILE",
]
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
[docs]class LocalElasticAgent(SimpleElasticAgent):
""":py:class:`torchelastic.agent.server.ElasticAgent`的实现,处理主机本地的工作者。
此代理按主机部署,并配置为生成``n``个工作者。
当使用GPU时,``n``映射到主机上可用的GPU数量。
本地代理不会与其他主机上部署的本地代理通信,即使工作者可能进行跨主机通信。
工作者ID被解释为本地进程。代理将所有工作者进程作为一个单元启动和停止。
传递给工作者函数的工作者函数和参数必须是python多处理兼容的。
要向工作者传递多处理数据结构,您可以在与指定的``start_method``相同的多处理上下文中创建数据结构,并将其作为函数参数传递。
``exit_barrier_timeout``指定等待其他代理完成的时间量(以秒为单位)。
这作为处理工作者在不同时间完成情况的安全网,以防止代理将提前完成的工作者视为缩减事件。
强烈建议用户代码确保工作者以同步方式终止,而不是依赖exit_barrier_timeout。
如果在```LocalElasticAgent```进程中定义了值为1的环境变量``TORCHELASTIC_ENABLE_FILE_TIMER``,则可以在```LocalElasticAgent```中启用基于命名管道的看门狗。
可选地,另一个环境变量```TORCHELASTIC_TIMER_FILE```可以设置为命名管道的唯一文件名。
如果未设置环境变量```TORCHELASTIC_TIMER_FILE```,```LocalElasticAgent```将在内部创建一个唯一文件名并将其设置为环境变量```TORCHELASTIC_TIMER_FILE```,
此环境变量将传播到工作者进程,以允许它们连接到```LocalElasticAgent```使用的相同命名管道。
日志将写入指定的日志目录。每条日志行将默认以``[${role_name}${local_rank}]:``为前缀(例如``[trainer0]: foobar``)。
日志前缀可以通过传递一个`模板字符串
`_作为``log_line_prefix_template``参数来自定义。
以下宏(标识符)在运行时替换:
``${role_name}, ${local_rank}, ${rank}``。例如,要将每条日志行前缀设置为全局排名而不是本地排名,请设置``log_line_prefix_template = "[${rank}]:``。
示例启动函数
::
def trainer(args) -> str:
return "do train"
def main():
start_method="spawn"
shared_queue= multiprocessing.get_context(start_method).Queue()
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint=trainer,
args=("foobar",),
...)
agent = LocalElasticAgent(spec, start_method)
results = agent.run()
if results.is_failed():
print("trainer failed")
else:
print(f"rank 0 return value: {results.return_values[0]}")
# prints -> rank 0 return value: do train
示例启动二进制文件
::
def main():
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint="/usr/local/bin/trainer",
args=("--trainer-args", "foobar"),
...)
agent = LocalElasticAgent(spec)
results = agent.run()
if not results.is_failed():
print("binary launches do not have return values")
"""
def __init__(
self,
spec: WorkerSpec,
logs_specs: LogsSpecs,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_line_prefix_template: Optional[str] = None,
):
super().__init__(spec, exit_barrier_timeout)
self._start_method = start_method
self._pcontext: Optional[PContext] = None
self._rdzv_handler = spec.rdzv_handler
self._log_line_prefix_template = log_line_prefix_template
self._worker_watchdog: Optional[timer.FileTimerServer] = None
self._logs_specs = logs_specs
def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
watchdog_file_path = os.getenv(watchdog_file_env_name)
if watchdog_enabled is not None and str(watchdog_enabled) == "1":
if watchdog_file_path is