Neural networks

This module contains the implementation of the neural network models.

Overview

This module contains the core neural network building blocks used by qdisc.vae.core. It provides encoder and decoder architectures for quantum data, including:

  • circular and 2D convolutional encoders
  • transformer-based encoder and decoder blocks
  • autoregressive dense decoders
  • embedding and positional encoding utilities for discrete, shadow, and hybrid inputs

Encoders

This section defines encoder architectures that map input quantum data to the VAE latent space. It includes convolutional encoders for periodic 1D and 2D systems as well as a transformer-based encoder.



CircularConv1D


def CircularConv1D(
    features:int, kernel_size:int, strides:int=1, use_bias:bool=True,
    parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Convolutional layer with circular padding, used for quantum systems with periodic boundary conditions (PBC).
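
As an illustration of circular padding, the sketch below wraps a 1D signal around before applying a kernel, so that the first and last sites are treated as neighbours. It is a minimal pure-JAX example and not the CircularConv1D implementation; the function name circular_conv1d and the single-channel, odd-kernel setup are assumptions made for brevity.

```python
import jax.numpy as jnp

def circular_conv1d(x, kernel):
    """Cross-correlate a 1D signal with a kernel using wrap-around padding.

    x:      (length,) signal living on a ring
    kernel: (k,) filter, k assumed odd
    Returns an array of shape (length,).
    """
    pad = kernel.shape[0] // 2
    # Copy the last `pad` sites in front and the first `pad` sites at the end.
    x_padded = jnp.pad(x, (pad, pad), mode="wrap")
    # 'valid' keeps only positions where the kernel fully overlaps the padded signal.
    return jnp.correlate(x_padded, kernel, mode="valid")

x = jnp.array([1.0, 2.0, 3.0, 4.0])
kernel = jnp.array([1.0, 1.0, 1.0])
y = circular_conv1d(x, kernel)  # [7., 6., 9., 8.]
```

With this padding, translating the input around the ring translates the output in the same way, which is the property one wants for a periodic system.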



CNNcirc1D


def EncoderCNNcirc1D(
    latent_dim:int, num_conv_layers:int=2, kernel_size:int=3, strides:int=1, conv_features:int=16,
    dense_features:int=64, parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Encoder built from convolutional layers with circular padding, used for 1D quantum systems with periodic boundary conditions (PBC).

Args:

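latent_dim: int, dimension of the latent space
num_conv_layers: int, number of convolutional layers (default 2)
kernel_size: int, size of the convolution kernel (default 3)
strides: int, stride of the convolution (default 1)
conv_features: int, number of features of the convolutional layers (default 16)
dense_features: int, number of dense features (default 64)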

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)
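
The encoders return mean and logvar, which in a VAE usually parameterize a diagonal Gaussian over the latent space. The sketch below shows how such outputs are commonly turned into a latent sample with the reparameterization trick. The helper name sample_latent is hypothetical, and the sampling actually performed in qdisc.vae.core may differ.

```python
import jax
import jax.numpy as jnp

def sample_latent(key, mean, logvar):
    """Draw z ~ N(mean, exp(logvar)) in a way that stays differentiable."""
    std = jnp.exp(0.5 * logvar)               # log-variance -> standard deviation
    eps = jax.random.normal(key, mean.shape)  # noise independent of the parameters
    return mean + std * eps

key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 3))      # (batch, latent_dim)
logvar = jnp.zeros((4, 3))    # unit variance
z = sample_latent(key, mean, logvar)  # shape (4, 3)
```

Predicting the log-variance rather than the variance keeps the standard deviation positive without any explicit constraint on the network output.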


CNN2D


def EncoderCNN2D(
    latent_dim:int, num_conv_layers:int=2, kernel_size:int=(2, 2), strides:int=(1, 1), conv_features:int=16,
    dense_features:int=64, parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Encoder built from convolutional layers, used for 2D systems.

Args:

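latent_dim: int, dimension of the latent space
num_conv_layers: int, number of convolutional layers (default 2)
lattice_topology: jnp.ndarray describing the topology (and ordering) of the system
kernel_size: int or tuple of ints, size of the convolution kernel (default (2, 2))
strides: int or tuple of ints, stride of the convolution (default (1, 1))
conv_features: int, number of features of the convolutional layers (default 16)
dense_features: int, number of dense features (default 64)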

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)


Embedding


def Embedding(
    d_model:int, data_type:str, local_dimension:int=2,
    parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Embedding layer for the supported input types (selected with data_type): discrete, shadow, and hybrid inputs.



shift_right


def shift_right(
    x, start_token
):

Shift the input one position to the right, starting the sequence with start_token. This is used to make the decoder autoregressive.
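
A minimal sketch of such a shift is given below. It assumes x has shape (batch, seq_len) and that start_token is a scalar; the name shift_right_sketch is hypothetical, and the module's own shift_right may handle other shapes or token types.

```python
import jax.numpy as jnp

def shift_right_sketch(x, start_token):
    """Prepend start_token to each sequence and drop its last entry."""
    # One start token per sequence in the batch, with the same dtype as x.
    start = jnp.full((x.shape[0], 1), start_token, dtype=x.dtype)
    # Position i of the result holds x[:, i - 1], so a model reading the
    # shifted input at position i only sees sites strictly before i.
    return jnp.concatenate([start, x[:, :-1]], axis=1)

x = jnp.array([[1, 0, 1, 1],
               [0, 0, 1, 0]])
shifted = shift_right_sketch(x, start_token=2)
# shifted == [[2, 1, 0, 1],
#             [2, 0, 0, 1]]
```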



PositionalEncoding


def PositionalEncoding(
    d_model:int, max_len:int, parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>,
    name:Optional=None
)->None:

Positional encoding module.
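
The scheme used by PositionalEncoding is not stated here. One common choice, shown purely as an assumption, is the fixed sinusoidal encoding in which each pair of channels oscillates at its own frequency. The function name sinusoidal_encoding is hypothetical, and d_model is assumed to be even.

```python
import jax.numpy as jnp

def sinusoidal_encoding(max_len, d_model):
    """Return a (max_len, d_model) table of sinusoidal position codes."""
    positions = jnp.arange(max_len)[:, None]                 # (max_len, 1)
    # One frequency per pair of channels, decreasing geometrically.
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(0, d_model, 2) / d_model)
    angles = positions * freqs[None, :]                      # (max_len, d_model / 2)
    pe = jnp.zeros((max_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))                 # even channels
    pe = pe.at[:, 1::2].set(jnp.cos(angles))                 # odd channels
    return pe

pe = sinusoidal_encoding(max_len=10, d_model=8)  # shape (10, 8)
```

Such a table is typically added to the embedded inputs so that the attention layers can tell positions apart.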



Transformer_encoder


def Transformer_encoder(
    num_heads:int=2, d_model:int=8, num_layers:int=3, latent_dim:int=5, data_type:str='discrete',
    local_dimension:int=2, parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Encoder based on the transformer architecture

Args:

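num_heads: int, number of attention heads (default 2)
d_model: int, embedding dimension of the transformer (default 8)
num_layers: int, number of transformer layers (default 3)
latent_dim: int, dimension of the latent space (default 5)
data_type: str, type of the input data (default 'discrete')
local_dimension: int, number of possible outcomes per site (default 2)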

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)

Decoders

This section contains decoder architectures for recovering discrete, shadow, or hybrid quantum data from latent variables. It includes autoregressive dense decoders and a transformer-based decoder.



MaskedDense1D


def MaskedDense1D(
    features:int, exclude:bool=False, dtype:Any=float32,
    parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Masked linear layer
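
To illustrate how masking a linear layer enforces an autoregressive ordering, the sketch below zeroes the kernel entries that would let an output site see later input sites. It is a simplified example with one feature per site, not the MaskedDense1D implementation; the helper names are hypothetical, and reading exclude=True as "also hide the site itself" is an assumption.

```python
import jax.numpy as jnp

def autoregressive_mask(size, exclude=False):
    """mask[i, j] = 1 if output site j is allowed to depend on input site i."""
    # Upper triangular: output j sees inputs i <= j, or i < j when exclude=True.
    return jnp.triu(jnp.ones((size, size)), k=1 if exclude else 0)

def masked_dense(x, kernel, exclude=False):
    """Dense layer whose forbidden connections are multiplied by zero."""
    mask = autoregressive_mask(kernel.shape[0], exclude)
    return x @ (kernel * mask)

kernel = jnp.arange(1.0, 37.0).reshape(6, 6)  # deterministic non-zero weights
x = jnp.ones((1, 6))
x_mod = x.at[0, 2].set(0.0)                   # change the input at site 2

y = masked_dense(x, kernel, exclude=True)
y_mod = masked_dense(x_mod, kernel, exclude=True)
unchanged = (y == y_mod)[0]  # True at sites 0, 1, 2 and False afterwards
```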

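The following check verifies that ARNNDense (documented below) is autoregressive: changing the input at site 2 must leave the conditional probabilities of the first three sites unchanged.

import jax
import jax.numpy as jnp

decoder = ARNNDense()
x = jnp.ones((100, 10))          # input configurations, shape (batch, seq_len)
z = jnp.ones((x.shape[0], 5))    # latent variables, shape (batch, latent_dim)
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0, 2].set(0)           # modify site 2 of the first sample
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0] == cp2[0]  # the first 3 conditional probabilities should be the same

Array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False]], dtype=bool)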


ARNNDense


def ARNNDense(
    num_layers:int=3, features:int=36, local_dimension:int=2,
    parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Autoregressive neural network with dense layers

Args:

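num_layers: int, number of layers (default 3)
features: int, number of features of the dense layers (default 36)
local_dimension: int, number of possible outcomes per site (default 2)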

Returns:

log_cp: (batch, seq_len, local_dimension)

The same check applied to Transformer_decoder (documented below): changing the input at site 2 must leave the conditional probabilities of the first three sites unchanged.

import jax
import jax.numpy as jnp

decoder = Transformer_decoder(num_heads=1,
                              d_model=4,
                              num_layers=1,
                              data_type='discrete',
                              local_dimension=2)

x = jnp.ones((100, 10))          # input configurations, shape (batch, seq_len)
z = jnp.ones((x.shape[0], 5))    # latent variables, shape (batch, latent_dim)
inputs = (z, x)
key = jax.random.PRNGKey(3124)
params_decoder = decoder.init(key, inputs)

cp1 = decoder.apply(params_decoder, inputs)
x2 = x.at[0, 2].set(0)           # modify site 2 of the first sample
inputs = (z, x2)
cp2 = decoder.apply(params_decoder, inputs)
cp1[0] == cp2[0]  # the first 3 conditional probabilities should be the same

Array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False],
       [False, False]], dtype=bool)


Transformer_decoder


def Transformer_decoder(
    num_heads:int=4, d_model:int=32, num_layers:int=3, data_type:str='discrete', local_dimension:int=2,
    parent:Union=<flax.linen.module._Sentinel object at 0x7f74545e5280>, name:Optional=None
)->None:

Decoder based on the transformer architecture

Args:

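num_heads: int, number of attention heads (default 4)
d_model: int, embedding dimension of the transformer (default 32)
num_layers: int, number of transformer layers (default 3)
data_type: str, type of the input data (default 'discrete')
local_dimension: int, number of possible outcomes per site (default 2)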

Returns:

log_cp: (batch, seq_len, local_dimension)
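
Both decoders return log_cp with one entry per site and local outcome. Assuming these are normalized log conditional probabilities and that x holds integer outcomes in {0, ..., local_dimension - 1}, the log-probability of a full configuration follows from the chain rule by summing the entries selected by x. The sketch below illustrates this; the helper name sequence_log_prob is hypothetical and not part of the module.

```python
import jax
import jax.numpy as jnp

def sequence_log_prob(log_cp, x):
    """Sum the log conditional probabilities of the observed outcomes.

    log_cp: (batch, seq_len, local_dimension) log conditionals
    x:      (batch, seq_len) outcomes in {0, ..., local_dimension - 1}
    Returns an array of shape (batch,).
    """
    idx = x.astype(jnp.int32)[..., None]                          # (batch, seq_len, 1)
    per_site = jnp.take_along_axis(log_cp, idx, axis=-1)[..., 0]  # (batch, seq_len)
    return per_site.sum(axis=-1)

# Stand-in for a decoder output: uniform conditionals over 2 outcomes.
log_cp = jax.nn.log_softmax(jnp.zeros((3, 10, 2)), axis=-1)
x = jnp.ones((3, 10))
log_p = sequence_log_prob(log_cp, x)  # each entry equals 10 * log(0.5)
```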