Neural networks

This module contains the implementations of the neural network models (encoders and decoders).

Encoders

Convolutional


source

CircularConv1D


def CircularConv1D(
    features: int, kernel_size: int, strides: int = 1, use_bias: bool = True,
    parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Convolutional layer with circular padding (used for quantum systems with periodic boundary conditions, PBC).


source

CNNcirc1D


def EncoderCNNcirc1D(
    latent_dim: int, num_conv_layers: int = 2, kernel_size: int = 3, strides: int = 1, conv_features: int = 16,
    dense_features: int = 64, parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Encoder with convolutional layers with circular padding (used for 1d quantum systems with periodic boundary conditions, PBC).

Args:

latent_dim: int
num_conv_layers: int
kernel_size: int
strides: int
conv_features: int
dense_features: int

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)

source

CNN2D


def EncoderCNN2D(
    latent_dim: int, num_conv_layers: int = 2, kernel_size: tuple = (2, 2), strides: tuple = (1, 1), conv_features: int = 16,
    dense_features: int = 64, parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Encoder with convolutional layers, used for 2d systems.

Args:

latent_dim: int
num_conv_layers: int
lattice_topology: jnp.ndarray describing the topology (and ordering) of the system (NOTE: not present in the signature above — verify against the implementation)
kernel_size: int
strides: int
conv_features: int
dense_features: int

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)

Transformer


source

Embedding


def Embedding(
    d_model: int, data_type: str, local_dimension: int = 2,
    parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Embedding module for the various supported data types (`data_type`).


source

shift_right


def shift_right(
    x, start_token
):

Shift the input to the right by one position; used to enforce the autoregressive property of the decoder.


source

PositionalEncoding


def PositionalEncoding(
    d_model: int, max_len: int, parent: Union = <flax.linen sentinel>,
    name: Optional = None
) -> None:

Positional encoding module.


source

Transformer_encoder


def Transformer_encoder(
    num_heads: int = 2, d_model: int = 8, num_layers: int = 3, latent_dim: int = 5, data_type: str = 'discrete',
    local_dimension: int = 2, parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Encoder based on the transformer architecture

Args:

num_heads: int
d_model: int
num_layers: int
latent_dim: int
data_type: str
local_dimension: int

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)

Decoders

dense


source

MaskedDense1D


def MaskedDense1D(
    features: int, exclude: bool = False, dtype: Any = float32,
    parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Masked linear layer

# check the autoregressive property: changing the input at one site must not affect outputs at that site or earlier
decoder = ARNNDense()
x = jnp.ones((100,10))
z = jnp.ones((x.shape[0],5))
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0,2].set(0)
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0]==cp2[0] # the first 3 conditional probabilities should be unchanged (output i depends only on inputs before i)
Array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False]], dtype=bool)

source

ARNNDense


def ARNNDense(
    num_layers: int = 3, features: int = 36, local_dimension: int = 2,
    parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Autoregressive neural network with dense layers

Args:

num_layers: int
features: int
local_dimension: int

Returns:

log_cp: (batch, seq_len, local_dimension)

Transformer

# check the autoregressive property: changing the input at one site must not affect outputs at that site or earlier
decoder = Transformer_decoder(num_heads = 1,
            d_model = 4,
            num_layers = 1,
            data_type = 'discrete',
            local_dimension = 2)
            
x = jnp.ones((100,10))
z = jnp.ones((x.shape[0],5))
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0,2].set(0)
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0]==cp2[0] # the first 3 conditional probabilities should be unchanged (output i depends only on inputs before i)
Array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False]], dtype=bool)

source

Transformer_decoder


def Transformer_decoder(
    num_heads: int = 4, d_model: int = 32, num_layers: int = 3, data_type: str = 'discrete', local_dimension: int = 2,
    parent: Union = <flax.linen sentinel>, name: Optional = None
) -> None:

Decoder based on the transformer architecture

Args:

num_heads: int
d_model: int
num_layers: int
data_type: str
local_dimension: int

Returns:

log_cp: (batch, seq_len, local_dimension)