jax.random.wrap_key_data

jax.random.wrap_key_data#

jax.random.wrap_key_data(key_bits_array, *, impl=None)[源代码][源代码]#

将一组关键数据位包装成一个PRNG密钥数组。

参数:
  • key_bits_array (Array) – 一个 uint32 数组,其尾随形状对应于由 impl 指定的 PRNG 实现的键形状。

  • impl (PRNGSpecDesc | None) – 可选,指定一个PRNG实现,如 random.key

返回:

一个PRNG键数组,其dtype是``jax.dtypes.prng_key``的子dtype,对应于``impl``,其形状等于``key_bits_array.shape``的前导形状,直到键位维度。