jax.lax.switch#
- jax.lax.switch(index, branches, *operands, operand=<object object>)[源代码][源代码]#
应用由
index
给出的branches
中的一个。如果
index
超出范围,它将被限制在范围内。具有以下 Python 的语义:
def switch(index, branches, *operands): index = clamp(0, index, len(branches) - 1) return branches[index](*operands)
在内部,这封装了 XLA 的 Conditional 操作符。然而,当使用
vmap()
转换以对一批谓词进行操作时,cond
被转换为select()
。- 参数:
index – 整数标量类型,指示应用哪个分支函数。
branches (Sequence[Callable]) – 基于
index
应用的函数序列(A -> B)。operands – 操作数 (A) 输入到应用的任何分支。
- 返回:
基于
index
选择的分支的branch(*operands)
的值 (B)。