# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stax is a small but flexible neural net specification library from scratch.

You likely do not mean to import this module! Stax is intended as an example
library only. There are a number of other much more fully-featured neural
network libraries for JAX, including `Flax`_ from Google, and `Haiku`_ from
DeepMind.

.. _Haiku: https://github.com/deepmind/dm-haiku
.. _Flax: https://github.com/google/flax
"""

import functools
import operator as op

from jax import lax
from jax import random
import jax.numpy as jnp

from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
                    leaky_relu, selu, gelu, standardize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros

# aliases for backwards compatibility
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax

# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.

# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b
  return init_fun, apply_fun
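
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): how the (init_fun, apply_fun) convention described above is
# used with a single layer. The helper name `_example_dense_usage` is
# hypothetical.
def _example_dense_usage():
  init_fun, apply_fun = Dense(128)
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (-1, 32))   # out_shape == (-1, 128)
  inputs = jnp.ones((8, 32))
  return apply_fun(params, inputs)              # shape (8, 128)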

def GeneralConv(dimension_numbers, out_chan, filter_shape,
                strides=None, padding='VALID', W_init=None,
                b_init=normal(1e-6)):
  """Layer construction function for a general convolution layer."""
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  strides = strides or one
  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
  def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    k1, k2 = random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                    dimension_numbers=dimension_numbers) + b
  return init_fun, apply_fun
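
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): constructing a GeneralConv with explicit NHWC inputs, HWIO
# kernels, and NHWC outputs. The helper name `_example_general_conv_usage` is
# hypothetical.
def _example_general_conv_usage():
  init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'),
                                    out_chan=32, filter_shape=(3, 3))
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (1, 28, 28, 3))  # (1, 26, 26, 32), 'VALID'
  inputs = jnp.ones((1, 28, 28, 3))
  return apply_fun(params, inputs)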

def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', W_init=None,
                         b_init=normal(1e-6)):
  """Layer construction function for a general transposed-convolution layer."""
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  strides = strides or one
  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
  def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_transpose_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    k1, k2 = random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return lax.conv_transpose(inputs, W, strides, padding,
                              dimension_numbers=dimension_numbers) + b
  return init_fun, apply_fun

def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
  """Layer construction function for a batch normalization layer."""
  _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
  _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
  axis = (axis,) if jnp.isscalar(axis) else axis
  def init_fun(rng, input_shape):
    shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    k1, k2 = random.split(rng)
    beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
    return input_shape, (beta, gamma)
  def apply_fun(params, x, **kwargs):
    beta, gamma = params
    # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
    z = standardize(x, axis, epsilon=epsilon)
    if center and scale: return gamma[ed] * z + beta[ed]
    if center: return z + beta[ed]
    if scale: return gamma[ed] * z
    return z
  return init_fun, apply_fun
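
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): with the default axis=(0, 1, 2), BatchNorm on NHWC inputs
# normalizes over batch and spatial dimensions and keeps per-channel beta and
# gamma of shape (C,). The helper name `_example_batchnorm_usage` is
# hypothetical.
def _example_batchnorm_usage():
  init_fun, apply_fun = BatchNorm()
  rng = random.PRNGKey(0)
  out_shape, (beta, gamma) = init_fun(rng, (-1, 28, 28, 64))
  # beta.shape == gamma.shape == (64,); out_shape == (-1, 28, 28, 64)
  inputs = jnp.ones((8, 28, 28, 64))
  return apply_fun((beta, gamma), inputs)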

def elementwise(fun, **fun_kwargs):
  """Layer that applies a scalar function elementwise on its inputs."""
  init_fun = lambda rng, input_shape: (input_shape, ())
  apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
  return init_fun, apply_fun
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)


def _pooling_layer(reducer, init_val, rescaler=None):
  def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
    """Layer construction function for a pooling layer."""
    strides = strides or (1,) * len(window_shape)
    rescale = rescaler(window_shape, strides, padding) if rescaler else None

    if spec is None:
      non_spatial_axes = 0, len(window_shape) + 1
    else:
      non_spatial_axes = spec.index('N'), spec.index('C')

    for i in sorted(non_spatial_axes):
      window_shape = window_shape[:i] + (1,) + window_shape[i:]
      strides = strides[:i] + (1,) + strides[i:]

    def init_fun(rng, input_shape):
      padding_vals = lax.padtype_to_pads(input_shape, window_shape,
                                         strides, padding)
      ones = (1,) * len(window_shape)
      out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape,
                                                strides, padding_vals,
                                                ones, ones)
      return out_shape, ()
    def apply_fun(params, inputs, **kwargs):
      out = lax.reduce_window(inputs, init_val, reducer, window_shape,
                              strides, padding)
      return rescale(out, inputs, spec) if rescale else out
    return init_fun, apply_fun
  return PoolingLayer

MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)


def _normalize_by_window_size(dims, strides, padding):
  def rescale(outputs, inputs, spec):
    if spec is None:
      non_spatial_axes = 0, inputs.ndim - 1
    else:
      non_spatial_axes = spec.index('N'), spec.index('C')

    spatial_shape = tuple(inputs.shape[i]
                          for i in range(inputs.ndim)
                          if i not in non_spatial_axes)
    one = jnp.ones(spatial_shape, dtype=inputs.dtype)
    window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
    for i in sorted(non_spatial_axes):
      window_sizes = jnp.expand_dims(window_sizes, i)

    return outputs / window_sizes
  return rescale

AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)


def Flatten():
  """Layer construction function for flattening all but the leading dim."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
    return output_shape, ()
  def apply_fun(params, inputs, **kwargs):
    return jnp.reshape(inputs, (inputs.shape[0], -1))
  return init_fun, apply_fun
Flatten = Flatten()


def Identity():
  """Layer construction function for an identity layer."""
  init_fun = lambda rng, input_shape: (input_shape, ())
  apply_fun = lambda params, inputs, **kwargs: inputs
  return init_fun, apply_fun
Identity = Identity()
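
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): max pooling over NHWC inputs; with spec=None a 2D window
# implies the N and C axes are the first and last dimensions. The helper name
# `_example_maxpool_usage` is hypothetical.
def _example_maxpool_usage():
  init_fun, apply_fun = MaxPool((2, 2), strides=(2, 2))
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (8, 28, 28, 1))   # (8, 14, 14, 1)
  inputs = jnp.ones((8, 28, 28, 1))
  return apply_fun(params, inputs)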

def FanOut(num):
  """Layer construction function for a fan-out layer."""
  init_fun = lambda rng, input_shape: ([input_shape] * num, ())
  apply_fun = lambda params, inputs, **kwargs: [inputs] * num
  return init_fun, apply_fun

def FanInSum():
  """Layer construction function for a fan-in sum layer."""
  init_fun = lambda rng, input_shape: (input_shape[0], ())
  apply_fun = lambda params, inputs, **kwargs: sum(inputs)
  return init_fun, apply_fun
FanInSum = FanInSum()

def FanInConcat(axis=-1):
  """Layer construction function for a fan-in concatenation layer."""
  def init_fun(rng, input_shape):
    ax = axis % len(input_shape[0])
    concat_size = sum(shape[ax] for shape in input_shape)
    out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
    return out_shape, ()
  def apply_fun(params, inputs, **kwargs):
    return jnp.concatenate(inputs, axis)
  return init_fun, apply_fun

def Dropout(rate, mode='train'):
  """Layer construction function for a dropout layer with given rate.

  Note that `rate` is the probability of *keeping* each activation: kept
  activations are rescaled by `1 / rate` and the rest are zeroed out in
  'train' mode.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.get('rng', None)
    if rng is None:
      msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng=rng)` where `rng` is a "
             "PRNG key (e.g. from `jax.random.key`).")
      raise ValueError(msg)
    if mode == 'train':
      keep = random.bernoulli(rng, rate, inputs.shape)
      return jnp.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun
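
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): Dropout needs a PRNG key at apply time, passed as the `rng`
# keyword argument. The helper name `_example_dropout_usage` is hypothetical.
def _example_dropout_usage():
  init_fun, apply_fun = Dropout(0.9)          # keep probability 0.9
  rng = random.PRNGKey(0)
  init_rng, apply_rng = random.split(rng)
  _, params = init_fun(init_rng, (-1, 32))    # Dropout has no parameters
  inputs = jnp.ones((8, 32))
  return apply_fun(params, inputs, rng=apply_rng)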
# Composing layers via combinators

def serial(*layers):
  """Combinator for composing layers in serial.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
    composition of the given sequence of layers.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)
  def init_fun(rng, input_shape):
    params = []
    for init_fun in init_funs:
      rng, layer_rng = random.split(rng)
      input_shape, param = init_fun(layer_rng, input_shape)
      params.append(param)
    return input_shape, params
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    for fun, param, rng in zip(apply_funs, params, rngs):
      inputs = fun(param, inputs, rng=rng, **kwargs)
    return inputs
  return init_fun, apply_fun
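
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): composing a small classifier with serial. The helper name
# `_example_serial_usage` is hypothetical.
def _example_serial_usage():
  init_fun, apply_fun = serial(Dense(64), Relu, Dense(10), LogSoftmax)
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (-1, 32))   # out_shape == (-1, 10)
  inputs = jnp.ones((8, 32))
  return apply_fun(params, inputs)              # log-probabilities, (8, 10)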

def parallel(*layers):
  """Combinator for composing layers in parallel.

  The layer resulting from this combinator is often used with the FanOut and
  FanInSum layers.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the
    parallel composition of the given sequence of layers. In particular, the
    returned layer takes a sequence of inputs and returns a sequence of outputs
    with the same length as the argument `layers`.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)
  def init_fun(rng, input_shape):
    rngs = random.split(rng, nlayers)
    return zip(*[init(rng, shape) for init, rng, shape
                 in zip(init_funs, rngs, input_shape)])
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    return [f(p, x, rng=r, **kwargs) for f, p, x, r in
            zip(apply_funs, params, inputs, rngs)]
  return init_fun, apply_fun
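
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): a residual-style block where FanOut duplicates the input,
# parallel runs a Dense branch and an Identity branch, and FanInSum adds the
# branch outputs. The helper name `_example_residual_block` is hypothetical.
def _example_residual_block():
  init_fun, apply_fun = serial(FanOut(2),
                               parallel(Dense(32), Identity),
                               FanInSum)
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (-1, 32))   # out_shape == (-1, 32)
  inputs = jnp.ones((8, 32))
  return apply_fun(params, inputs)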

def shape_dependent(make_layer):
  """Combinator to delay layer constructor pair until input shapes are known.

  Args:
    make_layer: a one-argument function that takes an input shape as an argument
      (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the same
    layer as returned by `make_layer` but with its construction delayed until
    input shapes are known.
  """
  def init_fun(rng, input_shape):
    return make_layer(input_shape)[0](rng, input_shape)
  def apply_fun(params, inputs, **kwargs):
    return make_layer(inputs.shape)[1](params, inputs, **kwargs)
  return init_fun, apply_fun
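
# Illustrative usage sketch (an editorial addition, not part of the original
# stax module): choosing a layer's width from the input's feature dimension,
# which is only known once a concrete shape is seen. The helper name
# `_example_shape_dependent_usage` is hypothetical.
def _example_shape_dependent_usage():
  def make_layer(input_shape):
    return Dense(2 * input_shape[-1])
  init_fun, apply_fun = shape_dependent(make_layer)
  rng = random.PRNGKey(0)
  out_shape, params = init_fun(rng, (-1, 16))   # out_shape == (-1, 32)
  inputs = jnp.ones((4, 16))
  return apply_fun(params, inputs)              # shape (4, 32)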