VAE

Overview

This page documents the core VAE utilities defined in qdisc.vae.core.

It covers the module's main components:

  • VAEmodel: the encoder-decoder VAE wrapper that produces latent variables and decoder outputs.
  • VAE losses: the KL and total correlation (TC) terms, and reconstruction losses for discrete, continuous and hybrid quantum data.
  • VAETrainer: a training helper that applies the VAE objective, tracks training history, and visualizes latent structure.

VAE losses

This section contains loss functions used by the VAE objective.


TC_term


def TC_term(
    mean:Array, logvar:Array, z:Array
)->Array:

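Computes the total correlation (TC) term, a measure of statistical dependence between latent dimensions, from the encoder outputs (mean, logvar) and the sampled latent variables z.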
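
The docstring does not name the estimator. As an illustration only, the sketch below implements a common minibatch estimate of the total correlation, TC = KL(q(z) || ∏_d q(z_d)), approximating the aggregate posterior q(z) by averaging q(z | x_j) over the batch (in the spirit of β-TCVAE, without its dataset-size correction). The helper tc_minibatch_estimate is hypothetical; the actual TC_term may use a different estimator or normalization.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_normal_pdf(mu, logvar, sample):
    # Elementwise log density of N(mu, exp(logvar)).
    return -0.5 * (jnp.log(2.0 * jnp.pi) + logvar + (sample - mu) ** 2 * jnp.exp(-logvar))


def tc_minibatch_estimate(mean, logvar, z):
    """Hypothetical minibatch TC estimate; mean, logvar, z have shape (batch, latent_dim)."""
    batch = z.shape[0]
    # log q(z_i[d] | x_j) for sample i, conditioning input j, latent dim d: (batch, batch, latent_dim)
    log_q_pairs = log_normal_pdf(mean[None, :, :], logvar[None, :, :], z[:, None, :])
    log_b = jnp.log(float(batch))
    # log q(z_i): average the joint density q(z_i | x_j) over the batch index j
    log_qz = logsumexp(log_q_pairs.sum(axis=2), axis=1) - log_b
    # sum_d log q(z_i[d]): average each marginal over j, then sum over dimensions
    log_prod_qz = (logsumexp(log_q_pairs, axis=1) - log_b).sum(axis=1)
    return jnp.mean(log_qz - log_prod_qz)


key_m, key_v, key_z = jax.random.split(jax.random.PRNGKey(0), 3)
mean = jax.random.normal(key_m, (16, 3))
logvar = 0.1 * jax.random.normal(key_v, (16, 3))
z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(key_z, (16, 3))
tc = tc_minibatch_estimate(mean, logvar, z)
```

With a batch of a single sample, the joint and the product of marginals coincide and the estimate is zero, so meaningful values need reasonably large batches.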


log_normal_pdf


def log_normal_pdf(
    mu:Array, logvar:Array, sample:Array
):

Calculates the log density of a diagonal Gaussian N(mu, exp(logvar)) evaluated at sample.
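
A minimal sketch of that density, log N(x; μ, σ²) = -½ [log 2π + log σ² + (x - μ)² / σ²] with σ² = exp(logvar). It assumes the value is returned elementwise; whether log_normal_pdf sums over the latent axis is not stated here.

```python
import jax.numpy as jnp


def log_normal_pdf(mu, logvar, sample):
    """Elementwise log N(sample; mu, exp(logvar)) for a diagonal Gaussian (sketch)."""
    # exp(-logvar) is 1 / variance, which avoids an explicit division
    return -0.5 * (jnp.log(2.0 * jnp.pi) + logvar + (sample - mu) ** 2 * jnp.exp(-logvar))


# A standard normal evaluated at its mean gives -0.5 * log(2 * pi) in every dimension
values = log_normal_pdf(jnp.zeros(3), jnp.zeros(3), jnp.zeros(3))
```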


kl_standard_normal


def kl_standard_normal(
    mean:Array, logvar:Array
)->Array:

KL divergence between the diagonal Gaussian posterior q(z|x) = N(mean, exp(logvar)) and the standard normal prior p(z) = N(0, I).
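
The KL term has a closed form for a diagonal Gaussian: KL = -½ Σ_d (1 + logvar_d - mean_d² - exp(logvar_d)). The sketch below assumes a per-sample result summed over latent dimensions; the library function may average over the batch instead.

```python
import jax.numpy as jnp


def kl_standard_normal(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) per sample, summed over the latent axis (sketch)."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


kl_zero = kl_standard_normal(jnp.zeros((2, 3)), jnp.zeros((2, 3)))  # posterior equals prior
kl_shift = kl_standard_normal(jnp.ones((1, 2)), jnp.zeros((1, 2)))  # unit mean shift in 2 dims
```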


gaussian_negloglik


def gaussian_negloglik(
    recon_mean:Array, recon_logvar:Array, targets:Array, eps:float=1e-08
)->Array:

Reconstruction loss for continuous data: the Gaussian negative log-likelihood of targets under N(recon_mean, exp(recon_logvar)).
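
A sketch of that loss, assuming it is the per-sample Gaussian negative log-likelihood summed over all non-batch axes. How eps is used is also an assumption: here it is added to the variance so that the division stays finite when exp(recon_logvar) underflows.

```python
import math

import jax.numpy as jnp


def gaussian_negloglik(recon_mean, recon_logvar, targets, eps=1e-8):
    """Per-sample Gaussian NLL, summed over every axis except the batch axis (sketch)."""
    var = jnp.exp(recon_logvar) + eps  # assumption: eps keeps the variance away from zero
    nll = 0.5 * (math.log(2 * math.pi) + jnp.log(var) + (targets - recon_mean) ** 2 / var)
    return jnp.sum(nll, axis=tuple(range(1, targets.ndim)))


x = jnp.zeros((2, 3))
nll_perfect = gaussian_negloglik(x, jnp.zeros_like(x), x)  # targets equal the mean, unit variance
```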


categorical_reconstruction_loss


def categorical_reconstruction_loss(
    log_cp_i:Array, targets:Array
)->Array:

Reconstruction loss for discrete data: the negative log-likelihood of targets under the log conditional probabilities log_cp_i.
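
A sketch, assuming log_cp_i has shape (batch, seq_len, local_dimension) like the cp output of VAEmodel and that targets holds integer outcomes in [0, local_dimension). The real function may instead expect one-hot targets or reduce over the batch differently.

```python
import jax.numpy as jnp


def categorical_reconstruction_loss(log_cp_i, targets):
    """Per-sample negative log-likelihood: -sum_i log cp_i(targets_i) (sketch)."""
    # Pick log cp_i at the observed outcome of each site: (batch, seq_len)
    picked = jnp.take_along_axis(log_cp_i, targets[..., None], axis=-1)[..., 0]
    return -jnp.sum(picked, axis=-1)


# Uniform conditionals over 2 outcomes on 4 sites give a loss of 4 * log(2) per sample
log_cp = jnp.full((3, 4, 2), jnp.log(0.5))
targets = jnp.array([[0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]])
loss = categorical_reconstruction_loss(log_cp, targets)
```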

VAE model wrapper

VAEmodel composes an encoder module and a decoder module. The encoder maps inputs to a Gaussian latent distribution (mean, logvar), the model samples latent variables with the reparameterization trick, and the decoder maps these latent variables to conditional probabilities (cp) over the outputs.

This wrapper makes it easy to use different encoder/decoder architectures while keeping a standard VAE interface.
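
The reparameterization trick writes z = mean + exp(logvar / 2) · ε with ε ~ N(0, I), so z stays differentiable with respect to mean and logvar. The sketch below shows this step inside a toy forward pass; the linear encoder, the factorized softmax decoder, and the names toy_forward and reparameterize are hypothetical stand-ins for the Flax modules passed to VAEmodel, chosen only to reproduce the output shapes (mean, logvar, z, cp).

```python
import jax
import jax.numpy as jnp

BATCH, SEQ_LEN, LOCAL_DIM, LATENT_DIM = 5, 4, 2, 3  # toy sizes


def reparameterize(key, mean, logvar):
    """Sample z ~ N(mean, exp(logvar)) as a differentiable function of mean and logvar."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def toy_forward(params, x, key):
    """Hypothetical encoder -> reparameterize -> decoder pass with linear layers."""
    x_flat = jax.nn.one_hot(x, LOCAL_DIM).reshape(x.shape[0], -1)
    mean, logvar = jnp.split(x_flat @ params["enc"], 2, axis=-1)
    z = reparameterize(key, mean, logvar)
    logits = (z @ params["dec"]).reshape(-1, SEQ_LEN, LOCAL_DIM)
    cp = jax.nn.softmax(logits, axis=-1)  # per-site probabilities over local outcomes
    return mean, logvar, z, cp


k_enc, k_dec, k_x, k_z = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "enc": 0.1 * jax.random.normal(k_enc, (SEQ_LEN * LOCAL_DIM, 2 * LATENT_DIM)),
    "dec": 0.1 * jax.random.normal(k_dec, (LATENT_DIM, SEQ_LEN * LOCAL_DIM)),
}
x = jax.random.randint(k_x, (BATCH, SEQ_LEN), 0, LOCAL_DIM)
mean, logvar, z, cp = toy_forward(params, x, k_z)
```

Because the noise enters additively, the gradient of sum(z) with respect to mean is exactly one, which is what lets the encoder be trained through the sampling step.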


VAEmodel


class VAEmodel(
    encoder:Module, decoder:Module, parent:Union=<flax.linen.module._Sentinel object>,
    name:Optional=None
):

VAE model: a wrapper that calls encoder -> reparameterization -> decoder.

Args:

encoder: nn.Module
decoder: nn.Module

Returns:

mean: (batch, latent_dim)
logvar: (batch, latent_dim)
z: (batch, latent_dim)
cp: (batch, seq_len, local_dimension)

Trainer

VAETrainer manages optimization, batching, and training history for the VAE.


VAETrainer


class VAETrainer(
    model:Module, dataset:Dataset, optimizer:GradientTransformation=adabelief
):

Training wrapper for the cpVAE applied to quantum data.

Args:

model: nn.Module (VAEmodel)

dataset: Dataset

optimizer: optax.GradientTransformation
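
A sketch of one optimization step on a β-weighted objective (categorical reconstruction loss + beta · KL), using a toy linear encoder and decoder. Two simplifications are assumptions: plain gradient descent replaces the optax optimizer (the signature's default is adabelief), and the gamma (TC) and alpha (hybrid) terms are left out. All names are hypothetical.

```python
import jax
import jax.numpy as jnp

SEQ_LEN, LOCAL_DIM, LATENT_DIM = 4, 2, 2  # toy sizes


def init_params(key):
    k1, k2, k3 = jax.random.split(key, 3)
    d_in = SEQ_LEN * LOCAL_DIM
    return {
        "w_mean": 0.1 * jax.random.normal(k1, (d_in, LATENT_DIM)),
        "w_logvar": 0.1 * jax.random.normal(k2, (d_in, LATENT_DIM)),
        "w_dec": 0.1 * jax.random.normal(k3, (LATENT_DIM, d_in)),
    }


def loss_fn(params, x, key, beta):
    """Batch mean of categorical NLL + beta * KL (TC and hybrid terms omitted)."""
    x_flat = jax.nn.one_hot(x, LOCAL_DIM).reshape(x.shape[0], -1)
    mean, logvar = x_flat @ params["w_mean"], x_flat @ params["w_logvar"]
    z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(key, mean.shape)
    logits = (z @ params["w_dec"]).reshape(x.shape[0], SEQ_LEN, LOCAL_DIM)
    log_cp = jax.nn.log_softmax(logits, axis=-1)
    recon = -jnp.take_along_axis(log_cp, x[..., None], axis=-1).sum(axis=(1, 2))
    kl = -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(recon + beta * kl)


@jax.jit
def sgd_step(params, x, key, beta, learning_rate):
    """One plain gradient-descent step (a stand-in for the optax optimizer)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, key, beta)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss


params = init_params(jax.random.PRNGKey(0))
x = jax.random.randint(jax.random.PRNGKey(1), (8, SEQ_LEN), 0, LOCAL_DIM)
step_key = jax.random.PRNGKey(2)
new_params, loss_before = sgd_step(params, x, step_key, 1.0, 1e-2)
loss_after = loss_fn(new_params, x, step_key, 1.0)
```

With optax, the update line would instead call optimizer.update(grads, opt_state, params) and optax.apply_updates; plain SGD keeps the example free of extra dependencies.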

VAETrainer methods

The following methods are implemented for training and analyzing the VAE:


VAETrainer.train


def train(
    num_epochs:int, batch_size:int, key:PRNGKey, learning_rate:float=0.001, beta:float=1.0, gamma:float=0,
    alpha:float=0.0, printing_rate:int=1, re_shuffle:bool=True
):

High-level training loop.

Args:

num_epochs: int
batch_size: int
key: jax.random.PRNGKey
learning_rate: float
beta: float KL weight
gamma: float TC weight
alpha: float weight balancing the two parts of the hybrid loss
printing_rate: int how often to print information during training
re_shuffle: bool whether to reshuffle the dataset when creating batches at each epoch

Returns:

None
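
A sketch of how batch_size, key and re_shuffle could interact inside train. It assumes, without confirmation from the source, that each epoch draws a fresh permutation from a split of key when re_shuffle is True, keeps the dataset order otherwise, and drops the last incomplete batch. epoch_batches is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def epoch_batches(key, num_samples, batch_size, re_shuffle=True):
    """Return the index batches for one epoch (last incomplete batch dropped)."""
    if re_shuffle:
        order = jax.random.permutation(key, num_samples)
    else:
        order = jnp.arange(num_samples)
    num_batches = num_samples // batch_size
    return [order[b * batch_size:(b + 1) * batch_size] for b in range(num_batches)]


key = jax.random.PRNGKey(0)
schedule = []
for epoch in range(3):
    key, epoch_key = jax.random.split(key)  # fresh randomness for each epoch
    schedule.append(epoch_batches(epoch_key, num_samples=10, batch_size=4))

fixed = epoch_batches(key, num_samples=10, batch_size=5, re_shuffle=False)
```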

VAETrainer.compute_repr2d


def compute_repr2d(
    theta_pair:tuple=(1, 0), values_other_thetas:tuple=(), return_latvar:bool=False
):

Computes mu, abs(mu) and logvar across the parameter space (2D).


VAETrainer.plot_repr2d


def plot_repr2d(
    latvar:Optional=None, theta_pair:tuple=(1, 0), subplot:bool=True
):

Plots mu, abs(mu) and logvar across the parameter space (2D).


VAETrainer.plot_training


def plot_training(
    num_epochs:int
):

Plots the evolution of the reconstruction loss and of logvar during training.


VAETrainer.get_data


def get_data()->dict:

Returns the data to be saved, as a dict.


VAETrainer.reconstruct_sample


def reconstruct_sample(
    input_samples:Array, key:PRNGKey
)->Array:

Reconstructs input samples: gets z from the encoder, then samples from the decoder given z.
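
A sketch of that procedure, assuming an autoregressive decoder, as the name cp (conditional probabilities) suggests: after encoding to z, each site is sampled from its conditional distribution given z and the sites already drawn. The linear encoder and decoder and the names site_logits and reconstruct_sample_sketch are hypothetical; the real method uses the trained VAEmodel.

```python
import jax
import jax.numpy as jnp

SEQ_LEN, LOCAL_DIM, LATENT_DIM = 4, 3, 2  # toy sizes
D_IN = SEQ_LEN * LOCAL_DIM


def init_params(key):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "enc": 0.1 * jax.random.normal(k1, (D_IN, 2 * LATENT_DIM)),
        "dec_z": 0.5 * jax.random.normal(k2, (SEQ_LEN, LATENT_DIM, LOCAL_DIM)),
        "dec_prev": 0.5 * jax.random.normal(k3, (SEQ_LEN, D_IN, LOCAL_DIM)),
    }


def site_logits(params, z, s, i):
    """Logits of site i given z and the earlier sites s[:, :i] (later sites are masked out)."""
    mask = (jnp.arange(SEQ_LEN) < i)[None, :, None]
    prev = (jax.nn.one_hot(s, LOCAL_DIM) * mask).reshape(s.shape[0], -1)
    return z @ params["dec_z"][i] + prev @ params["dec_prev"][i]


def reconstruct_sample_sketch(params, input_samples, key):
    """Encode to z (reparameterized), then sample the decoder site by site given z."""
    key_z, key = jax.random.split(key)
    x_flat = jax.nn.one_hot(input_samples, LOCAL_DIM).reshape(input_samples.shape[0], -1)
    mean, logvar = jnp.split(x_flat @ params["enc"], 2, axis=-1)
    z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(key_z, mean.shape)
    s = jnp.zeros_like(input_samples)
    for i in range(SEQ_LEN):
        key, sub = jax.random.split(key)
        s = s.at[:, i].set(jax.random.categorical(sub, site_logits(params, z, s, i), axis=-1))
    return s


params = init_params(jax.random.PRNGKey(0))
inputs = jax.random.randint(jax.random.PRNGKey(1), (6, SEQ_LEN), 0, LOCAL_DIM)
recon = reconstruct_sample_sketch(params, inputs, jax.random.PRNGKey(2))
```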


VAETrainer.get_cp


def get_cp(
    input_samples:Array
)->Array:

Returns the conditional probabilities (cp) that the model outputs for the input samples.