Overview
This module contains the core neural network building blocks used by qdisc.vae.core. It provides encoder and decoder architectures for quantum data, including:
circular and 2D convolutional encoders
transformer-based encoder and decoder blocks
autoregressive dense decoders
embedding and positional encoding utilities for discrete, shadow, and hybrid inputs
Encoders
This section defines encoder architectures that map input quantum data to the VAE latent space. It includes convolutional encoders for periodic 1D and 2D systems as well as a transformer-based encoder.
source
CircularConv1D
def CircularConv1D(
features:int , kernel_size:int , strides:int = 1 , use_bias:bool = True ,
parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Convolutional layer with circular padding (used for quantum systems with periodic boundary conditions, PBC)
source
CNNcirc1D
def EncoderCNNcirc1D(
latent_dim:int , num_conv_layers:int = 2 , kernel_size:int = 3 , strides:int = 1 , conv_features:int = 16 ,
dense_features:int = 64 , parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Encoder built from convolutional layers with circular padding (used for 1D quantum systems with periodic boundary conditions, PBC)
Args:
latent_dim: int
num_conv_layers: int
kernel_size: int
strides: int
conv_features: int
dense_features: int
Returns:
mean: (batch, latent_dim)
logvar: (batch, latent_dim)
source
CNN2D
def EncoderCNN2D(
latent_dim:int , num_conv_layers:int = 2 , kernel_size:int = (2 , 2 ), strides:int = (1 , 1 ), conv_features:int = 16 ,
dense_features:int = 64 , parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Encoder built from 2D convolutional layers (used for 2D systems)
Args:
latent_dim: int
num_conv_layers: int
lattice_topology: jnp.ndarray describing the topology (and ordering) of the system (note: this argument does not appear in the signature above; confirm whether it is still used)
kernel_size: tuple of ints (default: (2, 2))
strides: tuple of ints (default: (1, 1))
conv_features: int
dense_features: int
Returns:
mean: (batch, latent_dim)
logvar: (batch, latent_dim)
source
Embedding
def Embedding(
d_model:int , data_type:str , local_dimension:int = 2 ,
parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Embedding layer for the supported input data types (selected via data_type)
source
shift_right
def shift_right(
x, start_token
):
Shift the input to the right by one position; used to enforce the autoregressive property of the decoder
source
PositionalEncoding
def PositionalEncoding(
d_model:int , max_len:int , parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > ,
name:Optional= None
)-> None :
Positional encoding module.
source
Decoders
This section contains decoder architectures for recovering discrete, shadow, or hybrid quantum data from latent variables. It includes autoregressive dense decoders and a transformer-based decoder.
source
MaskedDense1D
def MaskedDense1D(
features:int , exclude:bool = False , dtype:Any= float32,
parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Masked linear layer
#check the autoregressivity
decoder = ARNNDense()
x = jnp.ones((100 ,10 ))
z = jnp.ones((x.shape[0 ],5 ))
inputs = (z, x)
key = jax.random.PRNGKey(3124 )
params_decoder = decoder.init(key, inputs)
cp1 = decoder.apply (params_decoder, inputs)
x2 = x.at[0 ,2 ].set (0 )
inputs = (z, x2)
cp2 = decoder.apply (params_decoder, inputs)
cp1[0 ]== cp2[0 ] #the first 3 cp should be the same
Array([[ True, True],
[ True, True],
[ True, True],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False]], dtype=bool)
source
ARNNDense
def ARNNDense(
num_layers:int = 3 , features:int = 36 , local_dimension:int = 2 ,
parent:Union=< flax.linen.module._Sentinel object at 0x7f74545e5280 > , name:Optional= None
)-> None :
Autoregressive neural network with dense layers
Args:
num_layers: int
features: int
local_dimension: int
Returns:
log_cp: (batch, seq_len, local_dimension)
#check the autoregressivity
decoder = Transformer_decoder(num_heads = 1 ,
d_model = 4 ,
num_layers = 1 ,
data_type = 'discrete' ,
local_dimension = 2 )
x = jnp.ones((100 ,10 ))
z = jnp.ones((x.shape[0 ],5 ))
inputs = (z, x)
key = jax.random.PRNGKey(3124 )
params_decoder = decoder.init(key, inputs)
cp1 = decoder.apply (params_decoder, inputs)
x2 = x.at[0 ,2 ].set (0 )
inputs = (z, x2)
cp2 = decoder.apply (params_decoder, inputs)
cp1[0 ]== cp2[0 ] #the first 3 cp should be the same
Array([[ True, True],
[ True, True],
[ True, True],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False],
[False, False]], dtype=bool)
source