VAE
This module contains the implementation of the Flax VAE model, its losses, and the trainer wrapper.
VAE + losses
log_importance_weight_matrix
def log_importance_weight_matrix(
batch_size:int, dataset_size:Array
)->Array:
Calculates a log importance weight matrix.
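The function name suggests the minibatch stratified sampling (MSS) weights from the β-TCVAE paper (Chen et al., 2018, "Isolating Sources of Disentanglement in VAEs"). The sketch below only illustrates the idea and is not this module's implementation. The matrix layout is an assumption: each sample's own posterior gets weight 1/N, one cyclic neighbour gets the stratified weight (N - M) / (N M) with M = batch_size - 1, and every other entry gets 1/M, so each row sums to one. The name `log_importance_weight_matrix_sketch` is hypothetical, and `dataset_size` is taken as a plain integer here.

```python
import jax.numpy as jnp

def log_importance_weight_matrix_sketch(batch_size: int, dataset_size: int) -> jnp.ndarray:
    """Hypothetical MSS-style log weight matrix (assumed layout, not the module's code).

    Requires batch_size >= 2 (M = batch_size - 1 appears in a denominator).
    """
    B, N = batch_size, dataset_size
    M = B - 1
    strat_weight = (N - M) / (N * M)
    # Default weight for the "other" batch elements.
    W = jnp.full((B, B), 1.0 / M)
    idx = jnp.arange(B)
    # Assumption: each sample's own posterior gets weight 1/N ...
    W = W.at[idx, idx].set(1.0 / N)
    # ... and one cyclic neighbour gets the stratified weight.
    W = W.at[idx, (idx + 1) % B].set(strat_weight)
    return jnp.log(W)

log_W = log_importance_weight_matrix_sketch(batch_size=4, dataset_size=100)
row_sums = jnp.exp(log_W).sum(axis=1)  # each row should sum to 1
```

Returning log-weights lets them be added to log-densities directly before a `logsumexp`. That combination is numerically safer than multiplying raw probabilities.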
matrix_log_density_gaussian
def matrix_log_density_gaussian(
x:Array, mu:Array, logvar:Array
)->Array:
Calculates the log density of a Gaussian for all pairs of batch elements.
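"All pairs" can be read as evaluating every sample `x_i` under every posterior `(mu_j, logvar_j)`, one latent dimension at a time. The minimal sketch below assumes that reading and assumes a `(batch, batch, latent_dim)` output with sample index `i` on axis 0 and posterior index `j` on axis 1. The name `matrix_log_density_gaussian_sketch` is hypothetical. Broadcasting builds the whole matrix without Python loops.

```python
import math
import jax.numpy as jnp
from jax.scipy.stats import norm

def matrix_log_density_gaussian_sketch(x, mu, logvar):
    """Pairwise log N(x_i,d | mu_j,d, exp(logvar_j,d)) -> shape (B, B, D) (assumed layout)."""
    x = x[:, None, :]            # (B, 1, D): sample i along axis 0
    mu = mu[None, :, :]          # (1, B, D): posterior j along axis 1
    logvar = logvar[None, :, :]  # (1, B, D)
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * jnp.exp(-logvar))

x = jnp.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
mu = jnp.array([[0.1, 0.9], [1.5, -0.5], [0.0, 0.0]])
logvar = jnp.zeros((3, 2))
mat = matrix_log_density_gaussian_sketch(x, mu, logvar)
# Cross-check one entry against jax.scipy: sample 1 under posterior 0 (unit variance).
check = norm.logpdf(x[1], loc=mu[0], scale=1.0)
```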
TC_term
def TC_term(
mean:Array, logvar:Array, z:Array
)->Array:
Computes the total correlation (TC) term.
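The total correlation is TC = E_q[log q(z) - Σ_d log q(z_d)]. It measures how far the aggregate posterior is from factorizing across latent dimensions. The signature takes no dataset size, so the sketch below assumes a plain minibatch estimate: q(z_i) is approximated by averaging q(z_i | x_j) over the batch, and each marginal q(z_{i,d}) is approximated the same way. This estimator choice and the names `tc_term_sketch` and `_pairwise_log_density` are assumptions, not the module's code.

```python
import math
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def _pairwise_log_density(x, mu, logvar):
    # (B, B, D): log N(x_i,d | mu_j,d, exp(logvar_j,d))
    x, mu, logvar = x[:, None, :], mu[None, :, :], logvar[None, :, :]
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * jnp.exp(-logvar))

def tc_term_sketch(mean, logvar, z):
    """Minibatch estimate of TC = E[log q(z) - sum_d log q(z_d)] (assumed estimator)."""
    B, _ = z.shape
    mat = _pairwise_log_density(z, mean, logvar)
    # log q(z_i) ~= logsumexp_j log q(z_i | x_j) - log B
    log_qz = logsumexp(mat.sum(axis=2), axis=1) - math.log(B)
    # sum_d log q(z_{i,d}), each marginal estimated over the batch the same way
    log_prod_qzi = (logsumexp(mat, axis=1) - math.log(B)).sum(axis=1)
    return jnp.mean(log_qz - log_prod_qzi)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
mean = jax.random.normal(k1, (16, 3))
logvar = jnp.zeros((16, 3))
z = mean + jax.random.normal(k2, (16, 3))
tc = tc_term_sketch(mean, logvar, z)
# With a single latent dimension q(z) and its only marginal coincide, so TC is exactly 0.
tc_1d = tc_term_sketch(mean[:, :1], logvar[:, :1], z[:, :1])
```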
log_density_gaussian
def log_density_gaussian(
x:Array, mu:Array, logvar:Array
):
Calculates log density of a Gaussian.
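For a diagonal Gaussian, the log density of each element is -½ (log 2π + logvar + (x - mu)² / exp(logvar)). The minimal sketch below assumes the function returns this elementwise, with no reduction over dimensions. The name `log_density_gaussian_sketch` is hypothetical. The check compares against `jax.scipy.stats.norm.logpdf` with scale = exp(logvar / 2).

```python
import math
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_density_gaussian_sketch(x, mu, logvar):
    """Elementwise log N(x | mu, exp(logvar)); no reduction (assumed)."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * jnp.exp(-logvar))

x = jnp.array([0.0, 1.0, -2.0])
mu = jnp.array([0.0, 0.5, 0.0])
logvar = jnp.array([0.0, math.log(4.0), -1.0])
ours = log_density_gaussian_sketch(x, mu, logvar)
ref = norm.logpdf(x, loc=mu, scale=jnp.exp(0.5 * logvar))
```

Parameterizing by `logvar` instead of the variance keeps the variance positive without constraints on the network output.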
log_normal_pdf
def log_normal_pdf(
mu:Array, logvar:Array, sample:Array
):
Computes the log-density of `sample` under a diagonal Gaussian with mean `mu` and log-variance `logvar`.
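Note the argument order here: `(mu, logvar, sample)`. The sketch below assumes the per-dimension log-densities are summed over the last (latent) axis, which gives one log-probability per sample. That reduction and the name `log_normal_pdf_sketch` are assumptions. Summing independent per-dimension terms should match a multivariate normal with diagonal covariance diag(exp(logvar)), which the check confirms with `jax.scipy.stats.multivariate_normal.logpdf`.

```python
import math
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def log_normal_pdf_sketch(mu, logvar, sample):
    """log N(sample | mu, diag(exp(logvar))), summed over the last axis (assumed reduction)."""
    per_dim = -0.5 * (math.log(2 * math.pi) + logvar + (sample - mu) ** 2 * jnp.exp(-logvar))
    return per_dim.sum(axis=-1)

mu = jnp.array([0.5, -1.0])
logvar = jnp.array([0.2, -0.3])
sample = jnp.array([1.0, 0.0])
ours = log_normal_pdf_sketch(mu, logvar, sample)
ref = multivariate_normal.logpdf(sample, mu, jnp.diag(jnp.exp(logvar)))
```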
kl_standard_normal
def kl_standard_normal(
mean:Array, logvar:Array
)->Array:
KL divergence between diag Gaussian q(z|x) ~ N(mean, exp(logvar)) and p(z)=N(0,I).
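This KL divergence has the closed form KL = -½ Σ_d (1 + logvar_d - mean_d² - exp(logvar_d)). The sketch below sums over the latent axis and returns one value per sample. Whether the module sums or averages over dimensions and over the batch is not stated, so that reduction is an assumption, as is the name `kl_standard_normal_sketch`.

```python
import jax.numpy as jnp

def kl_standard_normal_sketch(mean, logvar):
    """Closed-form KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return -0.5 * jnp.sum(1.0 + logvar - mean ** 2 - jnp.exp(logvar), axis=-1)

# q equal to the prior gives zero divergence.
zero_kl = kl_standard_normal_sketch(jnp.zeros((2, 3)), jnp.zeros((2, 3)))
# One latent dim with mean 1 and unit variance: KL = 0.5 * 1**2 = 0.5
half = kl_standard_normal_sketch(jnp.array([[1.0]]), jnp.array([[0.0]]))
```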
gaussian_negloglik
def gaussian_negloglik(
recon_mean:Array, recon_logvar:Array, targets:Array, eps:float=1e-08
)->Array:
Gaussian negative log-likelihood reconstruction loss for continuous data.
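A Gaussian negative log-likelihood is ½ (log 2π + log σ² + (x - μ)² / σ²) per element. The minimal sketch below assumes that `eps` is added to the variance to keep the division stable, and that the loss is summed over all non-batch axes to give one value per sample. Both choices and the name `gaussian_negloglik_sketch` are assumptions; the module may place `eps` or reduce differently.

```python
import math
import jax.numpy as jnp

def gaussian_negloglik_sketch(recon_mean, recon_logvar, targets, eps=1e-8):
    """Per-sample Gaussian NLL, summed over all non-batch axes (assumed reduction)."""
    var = jnp.exp(recon_logvar) + eps  # assumption: eps guards against a vanishing variance
    nll = 0.5 * (math.log(2 * math.pi) + jnp.log(var) + (targets - recon_mean) ** 2 / var)
    return nll.reshape(nll.shape[0], -1).sum(axis=1)

# Perfect mean prediction with unit variance: 3 elements * 0.5 * log(2*pi) per sample.
loss = gaussian_negloglik_sketch(jnp.zeros((2, 3)), jnp.zeros((2, 3)), jnp.zeros((2, 3)))
```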
categorical_reconstruction_loss
def categorical_reconstruction_loss(
log_cp_i:Array, targets:Array
)->Array:
Categorical reconstruction loss for discrete data, computed from the log-probabilities `log_cp_i`.
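The sketch below is one plausible reading of this loss. It assumes `log_cp_i` holds per-site log-probabilities of shape `(batch, seq_len, local_dim)`, matching the `cp` output of `VAEmodel`, and that `targets` are integer outcomes of shape `(batch, seq_len)`. It selects the log-probability of each observed outcome with `jnp.take_along_axis` and sums over sites. The target encoding, the reduction, and the name `categorical_reconstruction_loss_sketch` are all assumptions; one-hot targets would instead use `-(log_cp_i * targets).sum(...)`.

```python
import math
import jax.numpy as jnp

def categorical_reconstruction_loss_sketch(log_cp_i, targets):
    """Negative log-likelihood of integer targets under per-site categoricals (assumed shapes).

    log_cp_i: (batch, seq_len, local_dim) log-probabilities
    targets:  (batch, seq_len) integer outcomes in [0, local_dim)
    """
    picked = jnp.take_along_axis(log_cp_i, targets[..., None], axis=-1)[..., 0]
    return -jnp.sum(picked, axis=-1)  # sum over sites -> (batch,)

# Uniform probabilities over local_dim = 2 on 3 sites: loss = 3 * log 2 per sample.
log_cp = jnp.full((2, 3, 2), -math.log(2.0))
targets = jnp.array([[0, 1, 1], [1, 0, 0]])
loss = categorical_reconstruction_loss_sketch(log_cp, targets)
```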
VAEmodel
def VAEmodel(
encoder:Module, decoder:Module, parent:Union=<flax.linen.module._Sentinel object>,
name:Optional=None
)->None:
VAE model: a wrapper that chains encoder -> reparameterization -> decoder.
Args:
encoder: nn.Module
decoder: nn.Module
Returns:
mean: (batch, latent_dim)
logvar: (batch, latent_dim)
z: (batch, latent_dim)
cp: (batch, seq_len, local_dimension)
Trainer
VAETrainer
def VAETrainer(
model:Module, dataset:Dataset, optimizer:GradientTransformation=adabelief
):
Training wrapper for the cpVAE for quantum
Args:
model: nn.Module (VAEmodel)
dataset: Dataset
optimizer: optax.GradientTransformation