source
def CircularConv1D( features:int, kernel_size:int, strides:int=1, use_bias:bool=True, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Convolutional layer with circular padding (used for quantum systems with periodic boundary conditions, PBC).
def EncoderCNNcirc1D( latent_dim:int, num_conv_layers:int=2, kernel_size:int=3, strides:int=1, conv_features:int=16, dense_features:int=64, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Encoder built from convolutional layers with circular padding (used for 1D quantum systems with periodic boundary conditions, PBC).
Args:
latent_dim: int; num_conv_layers: int; kernel_size: int; strides: int; conv_features: int; dense_features: int
Returns:
mean: (batch, latent_dim); logvar: (batch, latent_dim)
def EncoderCNN2D( latent_dim:int, num_conv_layers:int=2, kernel_size:int=(2, 2), strides:int=(1, 1), conv_features:int=16, dense_features:int=64, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Encoder built from convolutional layers, used for 2D systems.
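latent_dim: int; num_conv_layers: int; lattice_topology: jnp.ndarray describing the topology (and ordering) of the system (note: `lattice_topology` does not appear in the signature of `EncoderCNN2D` shown above — confirm whether it is still accepted); kernel_size: int; strides: int; conv_features: int; dense_features: int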
def Embedding( d_model:int, data_type:str, local_dimension:int=2, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Embedding layer for the supported input data types (selected via `data_type`).
def shift_right( x, start_token ):
Shift the input to the right by one position; used to make the decoder autoregressive.
def PositionalEncoding( d_model:int, max_len:int, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Positional encoding module.
def Transformer_encoder( num_heads:int=2, d_model:int=8, num_layers:int=3, latent_dim:int=5, data_type:str='discrete', local_dimension:int=2, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Encoder based on the transformer architecture
num_heads: int; d_model: int; num_layers: int; latent_dim: int; data_type: str; local_dimension: int
def MaskedDense1D( features:int, exclude:bool=False, dtype:Any=float32, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Masked linear layer
#check the autoregressivity decoder = ARNNDense() x = jnp.ones((100,10)) z = jnp.ones((x.shape[0],5)) inputs = (z, x) key = jax.random.PRNGKey(3124) params_decoder = decoder.init(key, inputs) cp1 = decoder.apply(params_decoder, inputs) x2 = x.at[0,2].set(0) inputs = (z, x2) cp2 = decoder.apply(params_decoder, inputs) cp1[0]==cp2[0] #the first 3 cp should be the same
Array([[ True, True], [ True, True], [ True, True], [False, False], [False, False], [False, False], [False, False], [False, False], [False, False], [False, False]], dtype=bool)
def ARNNDense( num_layers:int=3, features:int=36, local_dimension:int=2, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Autoregressive neural network with dense layers
num_layers: int; features: int; local_dimension: int
log_cp: (batch, seq_len, local_dimension)
#check the autoregressivity decoder = Transformer_decoder(num_heads = 1, d_model = 4, num_layers = 1, data_type = 'discrete', local_dimension = 2) x = jnp.ones((100,10)) z = jnp.ones((x.shape[0],5)) inputs = (z, x) key = jax.random.PRNGKey(3124) params_decoder = decoder.init(key, inputs) cp1 = decoder.apply(params_decoder, inputs) x2 = x.at[0,2].set(0) inputs = (z, x2) cp2 = decoder.apply(params_decoder, inputs) cp1[0]==cp2[0] #the first 3 cp should be the same
def Transformer_decoder( num_heads:int=4, d_model:int=32, num_layers:int=3, data_type:str='discrete', local_dimension:int=2, parent:Union=<flax.linen.module._Sentinel object at 0x7f9be9751700>, name:Optional=None )->None:
Decoder based on the transformer architecture
num_heads: int; d_model: int; num_layers: int; data_type: str; local_dimension: int