jax.lax.switch

目录

jax.lax.switch#

jax.lax.switch(index, branches, *operands, operand=<object object>)[源代码][源代码]#

应用由 index 给出的 branches 中的一个。

如果 index 超出范围,它将被限制在范围内。

具有以下 Python 的语义:

def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)

在内部,这封装了 XLA 的 Conditional 操作符。然而,当使用 vmap() 转换以对一批谓词进行操作时,cond 被转换为 select()

参数:
  • index – 整数标量类型,指示应用哪个分支函数。

  • branches (Sequence[Callable]) – 基于 index 应用的函数序列(A -> B)。

  • operands – 操作数 (A) 输入到应用的任何分支。

返回:

基于 index 选择的分支的 branch(*operands) 的值 (B)。