VAE

This module contains the Flax implementation of the VAE model, its loss functions, and the trainer wrapper.

VAE + losses


log_importance_weight_matrix


def log_importance_weight_matrix(
    batch_size:int, dataset_size:Array
)->Array:

Calculates a log importance weight matrix.


matrix_log_density_gaussian


def matrix_log_density_gaussian(
    x:Array, mu:Array, logvar:Array
)->Array:

Calculates the log density of a Gaussian for all pairs of batch elements.
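A minimal JAX sketch of the pairwise computation, assuming entry [i, j, d] holds the log density of sample i, dimension d, under the Gaussian of batch element j. The axis order and the name matrix_log_density_gaussian_sketch are assumptions, not this module's code. Broadcasting a (batch, 1, latent_dim) view of x against (1, batch, latent_dim) views of mu and logvar produces all pairs at once.

```python
import math

import jax.numpy as jnp


def matrix_log_density_gaussian_sketch(x, mu, logvar):
    """Pairwise Gaussian log densities (hypothetical stand-in).

    x, mu, logvar: arrays of shape (batch, latent_dim).
    Returns shape (batch, batch, latent_dim); entry [i, j, d] is
    log N(x[i, d]; mu[j, d], exp(logvar[j, d])).
    """
    x = x[:, None, :]            # (batch, 1, latent_dim): sample i
    mu = mu[None, :, :]          # (1, batch, latent_dim): distribution j
    logvar = logvar[None, :, :]  # (1, batch, latent_dim)
    # Element-wise Gaussian log density, broadcast over every (i, j) pair
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * jnp.exp(-logvar))
```

Summing the result over the last axis gives log q(z_i | x_j) for full latent vectors, which is the quantity minibatch estimators of the aggregate posterior work with.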


TC_term


def TC_term(
    mean:Array, logvar:Array, z:Array
)->Array:

Computes the total correlation (TC) term.
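TC_term takes no dataset size, so it presumably relies on log_importance_weight_matrix internally; that is an assumption. For illustration, the sketch below swaps in the simpler minibatch weighted sampling estimator from β-TCVAE (Chen et al., 2018), which approximates log q(z_i) by logsumexp_j log q(z_i | x_j) - log(N M) for batch size M and dataset size N. The function name and its dataset_size argument are hypothetical.

```python
import math

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tc_term_mws_sketch(mean, logvar, z, dataset_size):
    """Total correlation via minibatch weighted sampling (hypothetical helper).

    mean, logvar, z: (batch, latent_dim); dataset_size: number of training points N.
    """
    batch_size = z.shape[0]
    # mat[i, j, d] = log q(z[i, d] | x_j) for a diagonal Gaussian posterior
    mat = -0.5 * (
        math.log(2 * math.pi)
        + logvar[None, :, :]
        + (z[:, None, :] - mean[None, :, :]) ** 2 * jnp.exp(-logvar[None, :, :])
    )
    log_nm = math.log(batch_size * dataset_size)
    # log q(z_i) ~= logsumexp_j log q(z_i | x_j) - log(N * M)
    log_qz = logsumexp(mat.sum(axis=2), axis=1) - log_nm
    # sum_d log q(z_{i,d}), each marginal estimated the same way
    log_prod_qzd = (logsumexp(mat, axis=1) - log_nm).sum(axis=1)
    # TC = E[log q(z) - sum_d log q(z_d)], averaged over the batch
    return jnp.mean(log_qz - log_prod_qzd)
```

For a one-dimensional latent space the joint and the product of marginals coincide, so the estimate is exactly zero, which makes a convenient sanity check.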


log_density_gaussian


def log_density_gaussian(
    x:Array, mu:Array, logvar:Array
):

Calculates the log density of a Gaussian.
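The element-wise Gaussian log density is log N(x; mu, exp(logvar)) = -0.5 * (log(2 pi) + logvar + (x - mu)^2 * exp(-logvar)). A minimal sketch, assuming the function is element-wise with no reduction (the _sketch name is hypothetical):

```python
import math

import jax.numpy as jnp


def log_density_gaussian_sketch(x, mu, logvar):
    """Element-wise log N(x; mu, exp(logvar)) (hypothetical stand-in)."""
    # -0.5 * [log(2*pi) + logvar + (x - mu)^2 / exp(logvar)]
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * jnp.exp(-logvar))
```

Working with logvar instead of the standard deviation keeps the variance positive without constraints and avoids taking the logarithm of exp(logvar).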


log_normal_pdf


def log_normal_pdf(
    mu:Array, logvar:Array, sample:Array
):

Computes the log density of sample under a Gaussian with mean mu and log-variance logvar.
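A sketch of how a log_normal_pdf-style helper is typically used, assuming it sums the Gaussian log density over the last (latent) axis: a Monte Carlo estimate of the KL term as log q(z|x) - log p(z), evaluated at reparameterized samples. Both function names are hypothetical, and the source does not state whether this module sums over the latent axis.

```python
import math

import jax
import jax.numpy as jnp


def log_normal_pdf_sketch(mu, logvar, sample):
    """Assumed behavior: Gaussian log density summed over the last (latent) axis."""
    per_dim = -0.5 * (math.log(2 * math.pi) + logvar + (sample - mu) ** 2 * jnp.exp(-logvar))
    return jnp.sum(per_dim, axis=-1)


def mc_kl_estimate(rng, mean, logvar, num_samples=1):
    """Monte Carlo estimate of KL(q(z|x) || N(0, I)) per batch element (hypothetical)."""
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)
    z = mean + jnp.exp(0.5 * logvar) * eps             # reparameterized samples
    log_qz_x = log_normal_pdf_sketch(mean, logvar, z)  # (num_samples, batch)
    log_pz = log_normal_pdf_sketch(0.0, 0.0, z)        # standard normal prior
    return jnp.mean(log_qz_x - log_pz, axis=0)         # (batch,)
```

With a single sample per example this is the usual unbiased but noisy estimate; averaging more samples converges to the closed form computed by kl_standard_normal.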


kl_standard_normal


def kl_standard_normal(
    mean:Array, logvar:Array
)->Array:

KL divergence between the diagonal Gaussian q(z|x) = N(mean, exp(logvar)) and the standard normal prior p(z) = N(0, I).
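For a diagonal Gaussian posterior the KL divergence to N(0, I) has the closed form KL = 0.5 * sum_d (exp(logvar_d) + mean_d^2 - 1 - logvar_d). A minimal sketch, assuming the sum runs over the last (latent) axis and one value is returned per batch element; whether the module also averages over the batch is not stated, and the _sketch name is hypothetical.

```python
import jax.numpy as jnp


def kl_standard_normal_sketch(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) per batch element (hypothetical stand-in).

    Closed form: 0.5 * sum_d (exp(logvar_d) + mean_d^2 - 1 - logvar_d).
    """
    return 0.5 * jnp.sum(jnp.exp(logvar) + mean ** 2 - 1.0 - logvar, axis=-1)
```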


gaussian_negloglik


def gaussian_negloglik(
    recon_mean:Array, recon_logvar:Array, targets:Array, eps:float=1e-08
)->Array:

Gaussian negative log-likelihood, used as the reconstruction loss for continuous data.
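A minimal sketch of a Gaussian negative log-likelihood with a stability constant, assuming eps is added to the variance and the loss is summed over all non-batch axes. Both choices, and the _sketch name, are assumptions rather than this module's documented behavior.

```python
import math

import jax.numpy as jnp


def gaussian_negloglik_sketch(recon_mean, recon_logvar, targets, eps=1e-8):
    """Gaussian NLL summed over all non-batch axes (hypothetical stand-in).

    Assumption: eps is added to the variance to keep the division and log stable.
    """
    var = jnp.exp(recon_logvar) + eps
    nll = 0.5 * (math.log(2 * math.pi) + jnp.log(var) + (targets - recon_mean) ** 2 / var)
    # Flatten everything except the batch axis and sum per example
    return jnp.sum(nll.reshape(nll.shape[0], -1), axis=-1)
```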


categorical_reconstruction_loss


def categorical_reconstruction_loss(
    log_cp_i:Array, targets:Array
)->Array:

Categorical reconstruction loss for discrete data.
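A minimal sketch, assuming log_cp_i holds per-site log-probabilities of shape (batch, seq_len, local_dimension), matching the cp output of VAEmodel, and that targets are integer outcomes of shape (batch, seq_len). The target encoding and the _sketch name are assumptions; with one-hot targets the equivalent is -sum(targets * log_cp_i) over the last two axes.

```python
import jax.numpy as jnp


def categorical_reconstruction_loss_sketch(log_cp_i, targets):
    """Negative log-likelihood of integer targets (hypothetical stand-in).

    log_cp_i: (batch, seq_len, local_dimension) log-probabilities per site.
    targets:  (batch, seq_len) integer outcomes in [0, local_dimension).
    Returns:  (batch,) loss, summed over sites.
    """
    # Pick log p(target) at each site, then drop the trailing singleton axis
    picked = jnp.take_along_axis(log_cp_i, targets[..., None], axis=-1)[..., 0]
    return -jnp.sum(picked, axis=-1)
```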


VAEmodel


def VAEmodel(
    encoder:Module, decoder:Module, parent:Union=<flax.linen.module._Sentinel object>,
    name:Optional=None
)->None:

VAE model: a wrapper that runs the encoder, the reparameterization step, and the decoder in sequence.

Args:

encoder: nn.Module
decoder: nn.Module

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)
z: (batch, latent_dim)
cp: (batch, seq_len, local_dimension)

Trainer


VAETrainer


def VAETrainer(
    model:Module, dataset:Dataset, optimizer:GradientTransformation=adabelief
):

Training wrapper for the quantum cpVAE.

Args:

model: nn.Module (VAEmodel)

dataset: Dataset

optimizer: optax.GradientTransformation