jax.example_libraries.optimizers
模块#
如何使用 JAX 编写优化器的示例。
您可能不打算导入此模块!此库中的优化器仅作为示例。如果您正在寻找一个功能齐全的优化器库,请考虑 Optax。
此模块包含一些方便的优化器定义,特别是初始化和更新函数,可以与ndarrays或任意嵌套的tuple/list/dicts的ndarrays一起使用。
一个优化器被建模为一个 (init_fun, update_fun, get_params)
函数的三个元组,其中组件函数具有以下签名:
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
请注意,优化器实现具有很大的灵活性,表现为 opt_state:它只需要是一个 JaxTypes 的 pytree(以便它可以传递给在 api.py 中定义的 JAX 变换),并且它必须可以被 update_fun 和 get_params 使用。
示例用法:
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)
def step(step, opt_state):
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
for i in range(num_steps):
value, opt_state = step(i, opt_state)
- class jax.example_libraries.optimizers.JoinPoint(subtree)[源代码][源代码]#
基类:
object
标记两个连接(嵌套)的 pytrees 之间的边界。
- class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[源代码][源代码]#
基类:
NamedTuple
- 参数:
init_fn (InitFn)
update_fn (UpdateFn)
params_fn (ParamsFn)
- init_fn: InitFn#
字段编号 0 的别名
- params_fn: ParamsFn#
字段编号2的别名
- update_fn: UpdateFn#
字段编号1的别名
- class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#
基类:
tuple
- packed_state#
字段编号 0 的别名
- subtree_defs#
字段编号2的别名
- tree_def#
字段编号1的别名
- jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[源代码][源代码]#
构建 Adagrad 的优化器三元组。
在线学习和随机优化的自适应子梯度方法: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
momentum – 可选,动量的正值标量
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[源代码][源代码]#
构建 Adam 的优化器三元组。
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
b1 – 可选,一个正标量值用于 beta_1,即第一矩估计的指数衰减率(默认 0.9)。
b2 – 可选,一个正标量值,用于 beta_2,即第二矩估计的指数衰减率(默认值为 0.999)。
eps – 可选,一个用于 epsilon 的正标量值,这是一个用于数值稳定性的微小常数(默认值为 1e-8)。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[源代码][源代码]#
构建 AdaMax 的优化器三元组(基于无穷范数的 Adam 变体)。
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
b1 – 可选,一个正标量值用于 beta_1,即第一矩估计的指数衰减率(默认 0.9)。
b2 – 可选,一个正标量值,用于 beta_2,即第二矩估计的指数衰减率(默认值为 0.999)。
eps – 可选,一个用于 epsilon 的正标量值,这是一个用于数值稳定性的微小常数(默认值为 1e-8)。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[源代码][源代码]#
将存储为数组的 pytree 的剪切梯度剪切到最大范数 max_norm。
- jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[源代码][源代码]#
- jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)[源代码][源代码]#
- 参数:
scalar_or_schedule (float | Schedule)
- 返回类型:
Schedule
- jax.example_libraries.optimizers.momentum(step_size, mass)[源代码][源代码]#
构建带有动量的SGD优化器三元组。
- 参数:
step_size (Schedule) – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
mass (float) – 表示动量系数的正标量。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.nesterov(step_size, mass)[源代码][源代码]#
构建带有Nesterov动量的SGD优化器三元组。
- 参数:
step_size (Schedule) – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
mass (float) – 表示动量系数的正标量。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.optimizer(opt_maker)[源代码][源代码]#
用于使针对数组定义的优化器推广到容器的装饰器。
使用这个装饰器,你可以编写 init、update 和 get_params 函数,它们各自仅对单个数组进行操作,并将其转换为对参数的 pytrees 进行操作的相应函数。参见 optimizers.py 中定义的优化器示例。
- 参数:
opt_maker (Callable[..., tuple[Callable[[Params], State], Callable[[Step, Updates, Params], Params], Callable[[State], Params]]]) –
一个返回
(init_fun, update_fun, get_params)
函数三元组的函数,这些函数可能仅适用于 ndarrays。init_fun :: ndarray -> OptStatePytree ndarray update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray get_params :: OptStatePytree ndarray -> ndarray
- 返回:
一个
(init_fun, update_fun, get_params)
函数三元组,它们按照以下方式处理任意 pytrees:- 返回类型:
Callable[…, Optimizer]
- jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[源代码][源代码]#
将标记的 pytree 转换为 OptimizerState。
unpack_optimizer_state 的逆操作。将一个带有标记的 pytree 转换回 OptimizerState,其中外部 pytree 的叶子表示为 JoinPoints。此函数旨在在反序列化优化器状态时使用。
- 参数:
marked_pytree – 一个包含 JoinPoint 叶子的 pytree,这些叶子持有更多的 pytree。
- 返回:
与输入参数等效的 OptimizerState。
- jax.example_libraries.optimizers.piecewise_constant(boundaries, values)[源代码][源代码]#
- 参数:
boundaries (Any)
values (Any)
- jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[源代码][源代码]#
- jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[源代码][源代码]#
构建RMSProp的优化器三元组。
- 参数:
step_size – 正标量,或表示步长调度的一个可调用对象,该调度将迭代索引映射到正标量。gamma: 衰减参数。eps: 欧米伽参数。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[源代码][源代码]#
构建带有动量的 RMSProp 优化器三元组。
这个优化器与rmsprop优化器分开,因为它需要跟踪额外的参数。
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
gamma – 衰减参数。
eps – Epsilon 参数。
momentum – 动量参数。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.sgd(step_size)[源代码][源代码]#
构建随机梯度下降的优化器三元组。
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。
- jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[源代码][源代码]#
为 SM3 构建优化器三元组。
大规模学习的内存高效自适应优化。 https://arxiv.org/abs/1901.11150
- 参数:
step_size – 正标量,或表示步长计划的回调函数,该计划将迭代索引映射为正标量。
momentum – 可选,动量的正值标量
- 返回:
一个 (init_fun, update_fun, get_params) 三元组。