jax.experimental.pallas.num_programs# jax.experimental.pallas.num_programs(axis)[源代码][源代码]# 返回网格沿给定轴的大小。 参数: axis (int) 返回类型: int | jax.Array