# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stax is a small but flexible neural net specification library from scratch.
You likely do not mean to import this module! Stax is intended as an example
library only. There are a number of other much more fully-featured neural
network libraries for JAX, including `Flax`_ from Google, and `Haiku`_ from
DeepMind.

.. _Haiku: https://github.com/deepmind/dm-haiku
.. _Flax: https://github.com/google/flax
"""
import functools
import operator as op
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, standardize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
# Aliases for backwards compatibility: each name below is simply a second
# binding for the initializer / activation imported above.
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax
# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.
#
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key (passed as the `rng`
#     keyword argument) and applies the layer.
# [docs]
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Build a dense (fully-connected) layer.

  Args:
    out_dim: size of the last axis of the layer output.
    W_init: initializer called as ``W_init(key, shape)`` for the weight matrix.
    b_init: initializer called as ``b_init(key, shape)`` for the bias vector.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  def init_fun(rng, input_shape):
    # Only the trailing axis is transformed; leading axes pass through.
    output_shape = input_shape[:-1] + (out_dim,)
    # Independent keys: the first seeds the weights, the second the bias.
    w_key, b_key = random.split(rng)
    weights = W_init(w_key, (input_shape[-1], out_dim))
    bias = b_init(b_key, (out_dim,))
    return output_shape, (weights, bias)

  def apply_fun(params, inputs, **kwargs):
    weights, bias = params
    projected = jnp.dot(inputs, weights)
    return projected + bias

  return init_fun, apply_fun
# [docs]
def GeneralConv(dimension_numbers, out_chan, filter_shape,
                strides=None, padding='VALID', W_init=None,
                b_init=normal(1e-6)):
  """Build a general convolution layer.

  Args:
    dimension_numbers: an (lhs_spec, rhs_spec, out_spec) triple of layout
      strings for the inputs, the kernel and the outputs, for example
      ('NHWC', 'HWIO', 'NHWC').
    out_chan: number of output channels.
    filter_shape: spatial extent of the kernel, one entry per spatial axis.
    strides: window strides; a falsy value selects 1 for every entry of
      ``filter_shape``.
    padding: padding argument forwarded to the lax convolution routines.
    W_init: kernel initializer; a falsy value selects ``glorot_normal`` built
      from the positions of 'I' and 'O' in ``rhs_spec``.
    b_init: bias initializer.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  unit = (1,) * len(filter_shape)
  if not strides:
    strides = unit
  if not W_init:
    W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

  def init_fun(rng, input_shape):
    # Walk the kernel layout: 'O' and 'I' get channel counts, every other
    # letter consumes the next spatial size from filter_shape.
    spatial_sizes = iter(filter_shape)
    kernel_shape = []
    for c in rhs_spec:
      if c == 'O':
        kernel_shape.append(out_chan)
      elif c == 'I':
        kernel_shape.append(input_shape[lhs_spec.index('C')])
      else:
        kernel_shape.append(next(spatial_sizes))
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    # The bias has size 1 everywhere except along the output channel axis.
    bias_shape = [1] * len(out_spec)
    for position, c in enumerate(out_spec):
      if c == 'C':
        bias_shape[position] = out_chan
    w_key, b_key = random.split(rng)
    kernel = W_init(w_key, kernel_shape)
    bias = b_init(b_key, bias_shape)
    return output_shape, (kernel, bias)

  def apply_fun(params, inputs, **kwargs):
    kernel, bias = params
    convolved = lax.conv_general_dilated(
        inputs, kernel, strides, padding,
        lhs_dilation=unit, rhs_dilation=unit,
        dimension_numbers=dimension_numbers)
    return convolved + bias

  return init_fun, apply_fun


# 2D convolution over NHWC inputs with an HWIO kernel.
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
# [docs]
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', W_init=None,
                         b_init=normal(1e-6)):
  """Build a general transposed-convolution layer.

  Args:
    dimension_numbers: an (lhs_spec, rhs_spec, out_spec) triple of layout
      strings for the inputs, the kernel and the outputs.
    out_chan: number of output channels.
    filter_shape: spatial extent of the kernel, one entry per spatial axis.
    strides: strides; a falsy value selects 1 for every entry of
      ``filter_shape``.
    padding: padding argument forwarded to the lax transposed-convolution
      routines.
    W_init: kernel initializer; a falsy value selects ``glorot_normal`` built
      from the positions of 'I' and 'O' in ``rhs_spec``.
    b_init: bias initializer.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  unit = (1,) * len(filter_shape)
  if not strides:
    strides = unit
  if not W_init:
    W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

  def init_fun(rng, input_shape):
    # Walk the kernel layout: 'O' and 'I' get channel counts, every other
    # letter consumes the next spatial size from filter_shape.
    spatial_sizes = iter(filter_shape)
    kernel_shape = []
    for c in rhs_spec:
      if c == 'O':
        kernel_shape.append(out_chan)
      elif c == 'I':
        kernel_shape.append(input_shape[lhs_spec.index('C')])
      else:
        kernel_shape.append(next(spatial_sizes))
    output_shape = lax.conv_transpose_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    # The bias has size 1 everywhere except along the output channel axis.
    bias_shape = [1] * len(out_spec)
    for position, c in enumerate(out_spec):
      if c == 'C':
        bias_shape[position] = out_chan
    w_key, b_key = random.split(rng)
    kernel = W_init(w_key, kernel_shape)
    bias = b_init(b_key, bias_shape)
    return output_shape, (kernel, bias)

  def apply_fun(params, inputs, **kwargs):
    kernel, bias = params
    transposed = lax.conv_transpose(inputs, kernel, strides, padding,
                                    dimension_numbers=dimension_numbers)
    return transposed + bias

  return init_fun, apply_fun


# 1D transposed convolution over NHC inputs with an HIO kernel.
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
# 2D transposed convolution over NHWC inputs with an HWIO kernel.
ConvTranspose = functools.partial(GeneralConvTranspose,
                                  ('NHWC', 'HWIO', 'NHWC'))
# [docs]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
  """Build a batch normalization layer.

  Args:
    axis: axis or axes over which inputs are standardized; a scalar is wrapped
      into a 1-tuple.
    epsilon: forwarded to ``standardize`` as ``epsilon``.
    center: if true, an offset parameter ``beta`` is added to the result.
    scale: if true, the result is multiplied by a parameter ``gamma``.
    beta_init: initializer for ``beta``; not called when ``center`` is false.
    gamma_init: initializer for ``gamma``; not called when ``scale`` is false.

  Returns:
    An (init_fun, apply_fun) pair. The parameters are a ``(beta, gamma)`` pair
    in which a disabled entry is the empty tuple.
  """
  def _beta_init(rng, shape):
    return beta_init(rng, shape) if center else ()

  def _gamma_init(rng, shape):
    return gamma_init(rng, shape) if scale else ()

  if jnp.isscalar(axis):
    axis = (axis,)

  def init_fun(rng, input_shape):
    # Parameters span only the axes that are not normalized over.
    kept_dims = []
    for position, size in enumerate(input_shape):
      if position not in axis:
        kept_dims.append(size)
    shape = tuple(kept_dims)
    beta_key, gamma_key = random.split(rng)
    beta = _beta_init(beta_key, shape)
    gamma = _gamma_init(gamma_key, shape)
    return input_shape, (beta, gamma)

  def apply_fun(params, x, **kwargs):
    beta, gamma = params
    # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    # Index that inserts a length-1 axis at each normalized position so the
    # parameters broadcast against x.
    expand = []
    for position in range(jnp.ndim(x)):
      expand.append(None if position in axis else slice(None))
    ed = tuple(expand)
    out = standardize(x, axis, epsilon=epsilon)
    if scale:
      out = gamma[ed] * out
    if center:
      out = out + beta[ed]
    return out

  return init_fun, apply_fun
# [docs]
def elementwise(fun, **fun_kwargs):
  """Wrap a function as a parameter-free layer applied to the layer inputs.

  Args:
    fun: callable invoked as ``fun(inputs, **fun_kwargs)``.
    **fun_kwargs: keyword arguments forwarded to ``fun`` on every call.

  Returns:
    An (init_fun, apply_fun) pair whose parameters are the empty tuple and
    whose output shape equals the input shape.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    # Layer-level kwargs (such as rng) are accepted but not passed to fun.
    return fun(inputs, **fun_kwargs)

  return init_fun, apply_fun
# Ready-made activation layers. Each name is an (init_fun, apply_fun) pair
# produced by `elementwise`, so it is used directly rather than called.
# LogSoftmax and Softmax bind axis=-1 for the wrapped function.
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)
def _pooling_layer(reducer, init_val, rescaler=None):
def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
"""Layer construction function for a pooling layer."""
strides = strides or (1,) * len(window_shape)
rescale = rescaler(window_shape, strides, padding) if rescaler else None
if spec is None:
non_spatial_axes = 0, len(window_shape) + 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
for i in sorted(non_spatial_axes):
window_shape = window_shape[:i] + (1,) + window_shape[i:]
strides = strides[:i] + (1,) + strides[i:]
def init_fun(rng, input_shape):
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
strides, padding)
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
input_shape, window_shape, strides, padding_vals, ones, ones)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
strides, padding)
return rescale(out, inputs, spec) if rescale else out
return init_fun, apply_fun
return PoolingLayer
MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)
def _normalize_by_window_size(dims, strides, padding):
def rescale(outputs, inputs, spec):
if spec is None:
non_spatial_axes = 0, inputs.ndim - 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
spatial_shape = tuple(inputs.shape[i]
for i in range(inputs.ndim)
if i not in non_spatial_axes)
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
for i in sorted(non_spatial_axes):
window_sizes = jnp.expand_dims(window_sizes, i)
return outputs / window_sizes
return rescale
# Average pooling: a window sum whose result is passed through the rescaler
# built by `_normalize_by_window_size`.
AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
def Flatten():
  """Layer construction function for flattening all but the leading dim."""
  def init_fun(rng, input_shape):
    # Collapse every axis after the first into a single axis.
    flat_size = 1
    for dim in input_shape[1:]:
      flat_size = flat_size * dim
    return (input_shape[0], flat_size), ()

  def apply_fun(params, inputs, **kwargs):
    leading = inputs.shape[0]
    return jnp.reshape(inputs, (leading, -1))

  return init_fun, apply_fun


# The constructor takes no arguments, so the name is rebound to the layer pair.
Flatten = Flatten()
def Identity():
  """Layer construction function for an identity layer."""
  def init_fun(rng, input_shape):
    # No parameters; the shape passes through unchanged.
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return inputs

  return init_fun, apply_fun


# The constructor takes no arguments, so the name is rebound to the layer pair.
Identity = Identity()
# [docs]
def FanOut(num):
  """Build a fan-out layer that repeats its input ``num`` times.

  Args:
    num: number of copies to emit.

  Returns:
    An (init_fun, apply_fun) pair; both the output shape and the output are
    lists of length ``num`` whose entries all refer to the same input object.
  """
  def init_fun(rng, input_shape):
    return [input_shape] * num, ()

  def apply_fun(params, inputs, **kwargs):
    return [inputs] * num

  return init_fun, apply_fun
def FanInSum():
  """Layer construction function for a fan-in sum layer."""
  def init_fun(rng, input_shape):
    # The output shape is that of the first input.
    first_shape = input_shape[0]
    return first_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return sum(inputs)

  return init_fun, apply_fun


# The constructor takes no arguments, so the name is rebound to the layer pair.
FanInSum = FanInSum()
# [docs]
def FanInConcat(axis=-1):
  """Build a fan-in layer that concatenates its inputs along ``axis``.

  Args:
    axis: axis to concatenate along; negative values count from the end.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  def init_fun(rng, input_shape):
    first = input_shape[0]
    # Resolve a negative axis against the rank of the first input.
    ax = axis % len(first)
    concat_size = 0
    for shape in input_shape:
      concat_size += shape[ax]
    out_shape = first[:ax] + (concat_size,) + first[ax + 1:]
    return out_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return jnp.concatenate(inputs, axis=axis)

  return init_fun, apply_fun
# [docs]
def Dropout(rate, mode='train'):
  """Layer construction function for a dropout layer with given rate.

  Args:
    rate: probability passed to ``random.bernoulli`` for *keeping* each
      element; kept elements are divided by ``rate`` and the rest are set
      to 0.
    mode: ``'train'`` applies dropout; any other value makes the layer return
      its inputs unchanged.

  Returns:
    An (init_fun, apply_fun) pair. In ``'train'`` mode ``apply_fun`` must be
    given a PRNG key through the ``rng`` keyword argument and raises
    ``ValueError`` otherwise. In any other mode no key is needed.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    # Outside training the layer is an identity and uses no randomness, so a
    # missing key is only an error on the training path.
    if mode != 'train':
      return inputs
    rng = kwargs.get('rng', None)
    if rng is None:
      # apply_fun accepts the key only as a keyword argument, so the message
      # shows the keyword form.
      msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng=rng)` where `rng` is a "
             "PRNG key (e.g. from `jax.random.key`).")
      raise ValueError(msg)
    keep = random.bernoulli(rng, rate, inputs.shape)
    return jnp.where(keep, inputs / rate, 0)

  return init_fun, apply_fun
# Composing layers via combinators
# [docs]
def serial(*layers):
  """Combinator for composing layers in serial.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair. An
      empty sequence yields a layer that returns its inputs unchanged.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
    composition of the given sequence of layers.
  """
  nlayers = len(layers)
  # zip(*()) produces nothing to unpack, so the empty composition is handled
  # explicitly instead of raising at construction time.
  init_funs, apply_funs = zip(*layers) if layers else ((), ())

  def init_fun(rng, input_shape):
    params = []
    for layer_init in init_funs:
      # Carry one half of the split forward and give the other to this layer.
      rng, layer_rng = random.split(rng)
      input_shape, param = layer_init(layer_rng, input_shape)
      params.append(param)
    return input_shape, params

  def apply_fun(params, inputs, **kwargs):
    # Remove rng from kwargs so each layer gets its own key.
    rng = kwargs.pop('rng', None)
    if rng is not None and nlayers:
      rngs = random.split(rng, nlayers)
    else:
      rngs = (None,) * nlayers
    for layer_apply, param, layer_rng in zip(apply_funs, params, rngs):
      inputs = layer_apply(param, inputs, rng=layer_rng, **kwargs)
    return inputs

  return init_fun, apply_fun
# [docs]
def parallel(*layers):
  """Combinator for composing layers in parallel.

  The layer resulting from this combinator is often used with the FanOut and
  FanInSum layers.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair. An
      empty sequence yields a layer with no outputs and no parameters.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the
    parallel composition of the given sequence of layers. In particular, the
    returned layer takes a sequence of inputs and returns a sequence of outputs
    with the same length as the argument `layers`.
  """
  nlayers = len(layers)
  # zip(*()) produces nothing to unpack, so the empty composition is handled
  # explicitly instead of raising at construction time.
  init_funs, apply_funs = zip(*layers) if layers else ((), ())

  def init_fun(rng, input_shape):
    if not nlayers:
      return (), ()
    rngs = random.split(rng, nlayers)
    results = [layer_init(layer_rng, shape) for layer_init, layer_rng, shape
               in zip(init_funs, rngs, input_shape)]
    # Return a concrete (output_shapes, params) pair, not a one-shot zip
    # iterator, so the result can be indexed or unpacked more than once.
    return tuple(zip(*results))

  def apply_fun(params, inputs, **kwargs):
    # Remove rng from kwargs so each layer gets its own key.
    rng = kwargs.pop('rng', None)
    if rng is not None and nlayers:
      rngs = random.split(rng, nlayers)
    else:
      rngs = (None,) * nlayers
    return [layer_apply(p, x, rng=layer_rng, **kwargs)
            for layer_apply, p, x, layer_rng
            in zip(apply_funs, params, inputs, rngs)]

  return init_fun, apply_fun
# [docs]
def shape_dependent(make_layer):
  """Combinator to delay layer constructor pair until input shapes are known.

  Args:
    make_layer: a one-argument function that takes an input shape as an argument
      (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the same
    layer as returned by `make_layer` but with its construction delayed until
    input shapes are known.
  """
  def init_fun(rng, input_shape):
    # Build the concrete layer for this shape and use its initializer.
    layer = make_layer(input_shape)
    layer_init = layer[0]
    return layer_init(rng, input_shape)

  def apply_fun(params, inputs, **kwargs):
    # The layer is rebuilt on every call from the shape of the inputs.
    layer = make_layer(inputs.shape)
    layer_apply = layer[1]
    return layer_apply(params, inputs, **kwargs)

  return init_fun, apply_fun