备注
Ray 2.10.0 引入了 RLlib 的“新 API 栈”的 alpha 阶段。Ray 团队计划将算法、示例脚本和文档迁移到新的代码库中,从而在 Ray 3.0 之前的后续小版本中逐步替换“旧 API 栈”(例如,ModelV2、Policy、RolloutWorker)。
然而,请注意,到目前为止,只有 PPO(单代理和多代理)和 SAC(仅单代理)支持“新 API 堆栈”,并且默认情况下继续使用旧 API 运行。您可以继续使用现有的自定义(旧堆栈)类。
请参阅此处 以获取有关如何使用新API堆栈的更多详细信息。
连接器 (Beta)#
连接器是处理给定RL策略的输入和输出转换的组件,目的是提高 RLlib策略检查点 的耐用性和可维护性。
RLlib 算法通常需要一个或多个 用户环境 和 策略 (通常是一个神经网络)。
从环境中观察到的数据通常在到达策略之前会经过多个预处理步骤,而策略的输出在用于控制环境中的特定代理之前也会经过多次转换。
通过在连接器框架下整合这些转换,RLlib 的用户将能够:
在不恢复 RLlib 算法相关训练逻辑的情况下,恢复和部署单个 RLlib 策略。
确保政策比它们所训练的算法更为持久。
允许策略适应不同版本的环境。
使用 RLlib 策略进行推理,无需担心确切的轨迹视图要求或状态输入。
可以通过将 enable_connectors
参数设置为 True
来启用连接器,使用 AlgorithmConfig.env_runners()
API。
关键概念#
我们有两类连接器。第一类是 AgentConnector
,用于将环境中的观察数据转换为策略。第二类是 ActionConnector
,用于将策略的输出转换为动作。
AgentConnector#
AgentConnectors
负责将环境观察数据转换为策略能够理解的格式(例如,将复杂的嵌套观察结果展平为扁平的张量)。高级API包括:
class AgentConnector(Connector):
def __call__(
self, acd_list: List[AgentConnectorDataType]
) -> List[AgentConnectorDataType]:
...
def transform(
self, ac_data: AgentConnectorDataType
) -> AgentConnectorDataType:
...
def reset(self, env_id: str):
...
def on_policy_output(self, output: ActionConnectorDataType):
...
AgentConnector 操作于一个观察数据列表。该列表是通过将映射到同一策略的代理的观察结果分组在一起构建的。
这种设置对于某些多代理使用场景非常有用,其中可能需要根据其他代理的数据来修改单个观察结果。如果用户需要构建元观察结果,例如,从单个代理观察结果构建一个图作为策略的输入,这也非常有用。
为了方便,如果一个 AgentConnector
不操作完整的代理数据列表,可以通过简单地重写 transform()
API 来实现。
AgentConnectors 还提供了一种方法,用于记录当前时间步(在通过 ActionConnectors 进行转换之前)策略的输出,以便稍后在下一个时间步中用于推理。这是通过 on_policy_output()
API 调用完成的,当您的策略是循环网络、注意力网络或自回归模型时,这非常有用。
ActionConnector#
ActionConnector
有一个更简单的 API,它针对单个操作进行操作:
class ActionConnector(Connector):
def __call__(
self, ac_data: ActionConnectorDataType
) -> ActionConnectorDataType:
...
def transform(
self, ac_data: ActionConnectorDataType
) -> ActionConnectorDataType:
...
在这种情况下,__call__
和 transform
是等价的。用户可以选择重写任一API来实现一个ActionConnector。
常见数据类型#
AgentConnectorDataType#
通过 AgentConnector
的每个代理的观察数据采用 AgentConnectorDataType
格式。
@OldAPIStack
class AgentConnectorDataType:
"""Data type that is fed into and yielded from agent connectors.
Args:
env_id: ID of the environment.
agent_id: ID to help identify the agent from which the data is received.
data: A payload (``data``). With RLlib's default sampler, the payload
is a dictionary of arbitrary data columns (obs, rewards, terminateds,
truncateds, etc).
"""
def __init__(self, env_id: str, agent_id: str, data: Any):
self.env_id = env_id
self.agent_id = agent_id
self.data = data
AgentConnectorsOutput#
RLlib 默认代理连接器管道的输出格式为 AgentConnectorsOutput
。
@OldAPIStack
class AgentConnectorsOutput:
"""Final output data type of agent connectors.
Args are populated depending on the AgentConnector settings.
The branching happens in ViewRequirementAgentConnector.
Args:
raw_dict: The raw input dictionary that sampler can use to
build episodes and training batches.
This raw dict also gets passed into ActionConnectors in case
it contains data useful for action adaptation (e.g. action masks).
sample_batch: The SampleBatch that can be immediately used for
querying the policy for next action.
"""
def __init__(
self, raw_dict: Dict[str, TensorStructType], sample_batch: "SampleBatch"
):
self.raw_dict = raw_dict
self.sample_batch = sample_batch
请注意,除了可以用于运行策略前向传递的已处理样本批次外,AgentConnectorsOutput
还提供了原始的原始输入字典,因为它有时包含下游处理所需的数据(例如,动作掩码)。
ActionConnectorDataType#
ActionConnectorDataType
是 ActionConnector
处理的数据类型。它基本上包括环境和代理ID、input_dict以及 PolicyOutputType
。原始输入字典在需要某些数据字段来适应动作输出时,例如动作掩码,会提供给动作连接器。
@OldAPIStack
class ActionConnectorDataType:
"""Data type that is fed into and yielded from agent connectors.
Args:
env_id: ID of the environment.
agent_id: ID to help identify the agent from which the data is received.
input_dict: Input data that was passed into the policy.
Sometimes output must be adapted based on the input, for example
action masking. So the entire input data structure is provided here.
output: An object of PolicyOutputType. It is is composed of the
action output, the internal state output, and additional data fetches.
"""
def __init__(
self,
env_id: str,
agent_id: str,
input_dict: TensorStructType,
output: PolicyOutputType,
):
self.env_id = env_id
self.agent_id = agent_id
self.input_dict = input_dict
self.output = output
之前,RLlib 策略的用户在调用策略之前必须提供正确的观察和状态输入。有了代理连接器,这项任务会自动处理。
PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict] # @OldAPIStack
高级连接器#
Lambda Connector 帮助将简单的转换函数转换为代理或动作连接器,而无需用户担心高级列表或非列表API。Lambda Connector 有单独的代理和动作版本,例如:
# An example agent connector that filters INFOS column out of
# observation data.
def filter(d: ActionConnectorDataType):
del d.data[SampleBatch.INFOS]
return d
FilterInfosColumnAgentConnector = register_lambda_agent_connector(
"FilterInfosColumnAgentConnector", filter
)
# An example action connector that scales actions output by the
# policy by a factor of 2.
ScaleActionConnector = register_lambda_action_connector(
"ScaleActionConnector",
lambda actions, states, fetches: 2 * actions, states, fetches
)
多个连接器可以组合成一个 ConnectorPipeline
,它按顺序处理所有子连接器的正确运行,并提供基本操作来修改和更新连接器的组合。
ConnectorPipeline
也有代理和动作版本:
# Example construction of an AgentConnectorPipeline.
pipeline = ActionConnectorPipeline(
ctx,
[ClipRewardAgentConnector(), ViewRequirementAgentConnector()]
)
# For demonstration purpose, we will add an ObsPreprocessorConnector
# in front of the ViewRequirementAgentConnector.
pipeline.insert_before("ViewRequirementAgentConnector", ObsPreprocessorConnector())
# Example construction of an ActionConnectorPipeline.
pipeline = ActionConnectorPipeline(
ctx,
[ConvertToNumpyConnector(), ClipActionsConnector(), ImmutableActionsConnector()]
)
# For demonstration purpose, we will drop the last ImmutableActionsConnector here.
pipeline.remove("ImmutableActionsConnector")
策略检查点#
如果启用了连接器,RLlib 将尝试以适当的序列化格式保存策略检查点,而不是依赖于 Python 的 pickle 序列化。最终目标是将以序列化的 JSON 文件保存策略检查点,以确保 RLlib 和 Python 版本之间的最大兼容性。
启用后,代理和动作连接器的配置将与检查点策略状态一起序列化并保存。这些连接器及其代表的特定转换可以通过 RLlib 提供的实用程序轻松恢复,从而简化部署和推理用例。
你可以在 这里 阅读更多关于策略检查点的内容。
服务与推理#
通过连接器基本上检查点所有在训练过程中使用的转换,策略可以很容易地恢复,而无需原始算法进行本地推理,如下面的Cartpole示例所示:
# Restore policy.
policy = Policy.from_checkpoint(
checkpoint=checkpoint_path,
policy_ids=[policy_id],
)
# Run CartPole.
env = gym.make("CartPole-v1")
env_id = "env_1"
obs, info = env.reset()
# Run for 2 episodes.
episodes = step = 0
while episodes < 2:
# Use local_policy_inference() to run inference, so we do not have to
# provide policy states or extra fetch dictionaries.
# "env_1" and "agent_1" are dummy env and agent IDs to run connectors with.
policy_outputs = local_policy_inference(
policy, env_id, "agent_1", obs, explore=False
)
assert len(policy_outputs) == 1
action, _, _ = policy_outputs[0]
print(f"episode {episodes} step {step}", obs, action)
# Step environment forward one more step.
obs, _, terminated, truncated, _ = env.step(action)
step += 1
# If the episode is done, reset the env and our connectors and start a new
# episode.
if terminated or truncated:
episodes += 1
step = 0
obs, info = env.reset()
policy.agent_connectors.reset(env_id)
RLlib 也将在不久后提供工具,使训练策略的服务器/客户端部署变得更加容易。请参阅 值得注意的待办事项。
为不同环境调整策略#
用户环境通常会经历活跃的开发迭代。使用旧版本环境训练的策略可能对更新后的环境无效。虽然环境包装器在许多情况下有助于解决这个问题,但连接器允许在不同环境中训练的策略同时协同工作。
以下是一个示例,展示了如何将针对标准Cartpole环境训练的策略适应于一个新的模拟Cartpole环境,该环境返回额外的特征并需要额外的动作输入。
class MyCartPole(gym.Env):
"""A mock CartPole environment.
Gives 2 additional observation states and takes 2 discrete actions.
"""
def __init__(self):
self._env = gym.make("CartPole-v1")
self.observation_space = gym.spaces.Box(low=-10, high=10, shape=(6,))
self.action_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
def step(self, actions):
# Take the first action.
action = actions[0]
obs, reward, done, truncated, info = self._env.step(action)
# Fake additional data points to the obs.
obs = np.hstack((obs, [8.0, 6.0]))
return obs, reward, done, truncated, info
def reset(self, *, seed=None, options=None):
obs, info = self._env.reset()
return np.hstack((obs, [8.0, 6.0])), info
# Custom agent connector to drop the last 2 feature values.
def v2_to_v1_obs(data: Dict[str, TensorStructType]) -> Dict[str, TensorStructType]:
data[SampleBatch.NEXT_OBS] = data[SampleBatch.NEXT_OBS][:-2]
return data
# Agent connector that adapts observations from the new CartPole env
# into old format.
V2ToV1ObsAgentConnector = register_lambda_agent_connector(
"V2ToV1ObsAgentConnector", v2_to_v1_obs
)
# Custom action connector to add a placeholder action as the addtional action input.
def v1_to_v2_action(
actions: TensorStructType, states: StateBatches, fetches: Dict
) -> PolicyOutputType:
return np.hstack((actions, [0])), states, fetches
# Action connector that adapts action outputs from the old policy
# into new actions for the mock environment.
V1ToV2ActionConnector = register_lambda_action_connector(
"V1ToV2ActionConnector", v1_to_v2_action
)
def run(checkpoint_path, policy_id):
# Restore policy.
policy = Policy.from_checkpoint(
checkpoint=checkpoint_path,
policy_ids=[policy_id],
)
# Adapt policy trained for standard CartPole to the new env.
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
# When this policy was trained, it relied on FlattenDataAgentConnector
# to add a batch dimension to single observations.
# This is not necessary anymore, so we first remove the previously used
# FlattenDataAgentConnector.
policy.agent_connectors.remove("FlattenDataAgentConnector")
# We then add the two adapter connectors.
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
policy.action_connectors.append(V1ToV2ActionConnector(ctx))
# Run CartPole.
env = MyCartPole()
obs, info = env.reset()
done = False
step = 0
while not done:
step += 1
# Use local_policy_inference() to easily run poicy with observations.
policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
assert len(policy_outputs) == 1
actions, _, _ = policy_outputs[0]
print(f"step {step}", obs, actions)
obs, _, done, _, _ = env.step(actions)
值得注意的待办事项#
将连接器引入离线算法。
将推出工作器过滤器迁移到连接器。
将剧集构建和训练样本收集迁移到连接器中。
在客户端-服务器远程环境中演示RLlib策略部署的示例和实用工具。