# Neural networks


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## Encoders

### Convolutional

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L17"
target="_blank" style="float:right; font-size:smaller">source</a>

### CircularConv1D

``` python
def CircularConv1D(
    features: int, kernel_size: int, strides: int = 1, use_bias: bool = True,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Convolutional layer with circular padding (used for quantum systems with
periodic boundary conditions, PBC)*
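
Circular padding wraps the sequence so that the kernel also sees neighbours across the boundary, which is what PBC require. A minimal sketch of the idea using `jnp.pad` with `mode='wrap'` (illustration only, not the layer's actual implementation):

``` python
import jax.numpy as jnp

# Circular (periodic) padding: sites at one end of the chain become
# neighbours of the sites at the other end.
x = jnp.arange(6)                      # one chain of 6 sites
pad = 1                                # (kernel_size - 1) // 2 for kernel_size = 3
x_pad = jnp.pad(x, pad, mode='wrap')   # [5, 0, 1, 2, 3, 4, 5, 0]
```

A convolution applied to `x_pad` with `'VALID'` padding then treats the chain as a ring.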

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L41"
target="_blank" style="float:right; font-size:smaller">source</a>

### EncoderCNNcirc1D

``` python
def EncoderCNNcirc1D(
    latent_dim: int, num_conv_layers: int = 2, kernel_size: int = 3, strides: int = 1,
    conv_features: int = 16, dense_features: int = 64,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Encoder with convolutional layers with circular padding (used for 1D
quantum systems with PBC)*

Args:

    latent_dim: int
    num_conv_layers: int
    kernel_size: int
    strides: int
    conv_features: int
    dense_features: int

Returns:

    mean: (batch, latent_dim)
    logvar: (batch, latent_dim)
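
The `(mean, logvar)` pair parameterizes a diagonal Gaussian over the latent space; downstream code typically samples from it with the reparameterization trick. A sketch with placeholder shapes (the encoder outputs stand in for `mean` and `logvar`):

``` python
import jax
import jax.numpy as jnp

# Reparameterisation trick: draw z ~ N(mean, exp(logvar)) differentiably.
# mean / logvar here are placeholders for the encoder outputs.
key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 5))      # (batch, latent_dim)
logvar = jnp.zeros((4, 5))    # (batch, latent_dim)
eps = jax.random.normal(key, mean.shape)
z = mean + jnp.exp(0.5 * logvar) * eps   # (batch, latent_dim)
```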

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L95"
target="_blank" style="float:right; font-size:smaller">source</a>

### EncoderCNN2D

``` python
def EncoderCNN2D(
    latent_dim: int, num_conv_layers: int = 2, kernel_size: tuple = (2, 2), strides: tuple = (1, 1),
    conv_features: int = 16, dense_features: int = 64,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Encoder with convolutional layers, used for 2D systems*

Args:

    latent_dim: int
    num_conv_layers: int
    lattice_topology: jnp.ndarray describing the topology (and ordering) of the system
    kernel_size: tuple
    strides: tuple
    conv_features: int
    dense_features: int

Returns:

    mean: (batch, latent_dim)
    logvar: (batch, latent_dim)

### Transformer

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L186"
target="_blank" style="float:right; font-size:smaller">source</a>

### Embedding

``` python
def Embedding(
    d_model: int, data_type: str, local_dimension: int = 2,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Embedding for the various `data_type` options*
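
For discrete data, an embedding is essentially a learned lookup table mapping each of the `local_dimension` local states to a `d_model`-dimensional vector. A sketch of that idea (the actual module may handle continuous `data_type` differently):

``` python
import jax
import jax.numpy as jnp

# Discrete embedding as a lookup table: local_dimension rows, d_model columns.
d_model, local_dimension = 8, 2
key = jax.random.PRNGKey(0)
table = jax.random.normal(key, (local_dimension, d_model))
x = jnp.array([[0, 1, 1, 0]])   # (batch, seq_len) of local states
emb = table[x]                  # (batch, seq_len, d_model)
```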

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L178"
target="_blank" style="float:right; font-size:smaller">source</a>

### shift_right

``` python

def shift_right(
    x, start_token
):

```

*Shifts the input right by one position; used to enforce the
autoregressivity of the decoder*
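
The operation prepends a start token and drops the last element, so that the prediction at position `i` only ever conditions on inputs at positions `< i`. An equivalent one-liner sketch (not necessarily the library's exact implementation):

``` python
import jax.numpy as jnp

def shift_right_sketch(x, start_token):
    # Prepend start_token along the sequence axis and drop the last element,
    # so position i only sees inputs at positions < i.
    start = jnp.full((x.shape[0], 1), start_token, dtype=x.dtype)
    return jnp.concatenate([start, x[:, :-1]], axis=1)

x = jnp.arange(8).reshape(2, 4)          # [[0, 1, 2, 3], [4, 5, 6, 7]]
y = shift_right_sketch(x, start_token=-1)  # [[-1, 0, 1, 2], [-1, 4, 5, 6]]
```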

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L153"
target="_blank" style="float:right; font-size:smaller">source</a>

### PositionalEncoding

``` python
def PositionalEncoding(
    d_model: int, max_len: int,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Positional encoding module.*
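
Such a module typically precomputes the standard sinusoidal encoding of Vaswani et al. (2017); a sketch of that construction, assuming an even `d_model`:

``` python
import jax.numpy as jnp

def positional_encoding_sketch(d_model, max_len):
    # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pos = jnp.arange(max_len)[:, None]
    i = jnp.arange(0, d_model, 2)[None, :]
    angle = pos / jnp.power(10000.0, i / d_model)
    pe = jnp.zeros((max_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angle))
    pe = pe.at[:, 1::2].set(jnp.cos(angle))
    return pe

pe = positional_encoding_sketch(d_model=8, max_len=16)   # (16, 8)
```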

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L234"
target="_blank" style="float:right; font-size:smaller">source</a>

### Transformer_encoder

``` python
def Transformer_encoder(
    num_heads: int = 2, d_model: int = 8, num_layers: int = 3, latent_dim: int = 5,
    data_type: str = 'discrete', local_dimension: int = 2,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Encoder based on the transformer architecture*

Args:

    num_heads: int
    d_model: int
    num_layers: int
    latent_dim: int
    data_type: str
    local_dimension: int

Returns:

    mean: (batch, latent_dim)
    logvar: (batch, latent_dim)

## Decoders

### Dense

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L310"
target="_blank" style="float:right; font-size:smaller">source</a>

### MaskedDense1D

``` python
def MaskedDense1D(
    features: int, exclude: bool = False, dtype: Any = float32,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Masked linear layer*

``` python
# Check the autoregressivity of ARNNDense (documented below): changing
# the input at position 2 must not affect the outputs at positions 0-2.
import jax
import jax.numpy as jnp

decoder = ARNNDense()
x = jnp.ones((100, 10))
z = jnp.ones((x.shape[0], 5))
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0, 2].set(0)
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0] == cp2[0]  # the first 3 positions should be unchanged
```

    Array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False]], dtype=bool)
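
The autoregressivity above typically comes from a triangular mask applied to the dense kernel, so that output `j` only mixes inputs at positions before `j` (with `exclude=True` the diagonal is masked as well). A minimal sketch of that masking, hypothetical rather than the library's exact implementation:

``` python
import jax
import jax.random
import jax.numpy as jnp

def masked_dense_sketch(x, kernel, exclude=True):
    # mask[i, j] = 1 iff output j may see input i; strictly upper
    # triangular (k=1) excludes the diagonal, k=0 includes it.
    n = kernel.shape[0]
    mask = jnp.triu(jnp.ones((n, n)), k=1 if exclude else 0)
    return x @ (kernel * mask)

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (5, 5))
x = jnp.ones((3, 5))
y = masked_dense_sketch(x, W)   # output 0 sees no inputs at all
```

With `exclude=True`, the first output column is identically zero, since position 0 has no preceding inputs to condition on.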

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L361"
target="_blank" style="float:right; font-size:smaller">source</a>

### ARNNDense

``` python
def ARNNDense(
    num_layers: int = 3, features: int = 36, local_dimension: int = 2,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Autoregressive neural network with dense layers*

Args:

    num_layers: int
    features: int
    local_dimension: int

Returns:

    log_cp: (batch, seq_len, local_dimension)
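
The returned log conditional probabilities can be turned into the log-probability of each configuration by gathering the entries at the observed local states and summing over the sequence. A sketch with placeholder values:

``` python
import jax.numpy as jnp

# log_cp: (batch, seq_len, local_dimension) log conditional probabilities;
# uniform placeholders here, a real model would produce these.
log_cp = jnp.log(jnp.full((2, 4, 2), 0.5))
x = jnp.zeros((2, 4), dtype=jnp.int32)   # (batch, seq_len) observed states

# log p(x) = sum_i log p(x_i | x_<i): gather at the observed states, sum.
log_p = jnp.take_along_axis(log_cp, x[..., None], axis=-1).squeeze(-1).sum(-1)
```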

### Transformer

``` python
# Check the autoregressivity of Transformer_decoder (documented below):
# changing the input at position 2 must not affect outputs at positions 0-2.
import jax
import jax.numpy as jnp

decoder = Transformer_decoder(num_heads=1,
                              d_model=4,
                              num_layers=1,
                              data_type='discrete',
                              local_dimension=2)

x = jnp.ones((100, 10))
z = jnp.ones((x.shape[0], 5))
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0, 2].set(0)
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0] == cp2[0]  # the first 3 positions should be unchanged
```

    Array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False],
           [False, False]], dtype=bool)

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/nn/core.py#L407"
target="_blank" style="float:right; font-size:smaller">source</a>

### Transformer_decoder

``` python
def Transformer_decoder(
    num_heads: int = 4, d_model: int = 32, num_layers: int = 3,
    data_type: str = 'discrete', local_dimension: int = 2,
    parent: Union = <flax.linen.module._Sentinel>, name: Optional = None
) -> None:
```

*Decoder based on the transformer architecture*

Args:

    num_heads: int
    d_model: int
    num_layers: int
    data_type: str
    local_dimension: int

Returns:

    log_cp: (batch, seq_len, local_dimension)
