VAE
Overview
This page documents the core VAE utilities used by qdisc.vae.core.
It presents the main tools behind the module:
VAEmodel: the encoder-decoder VAE wrapper that produces latent variables and decoder outputs.
Reconstruction losses for discrete, continuous and hybrid quantum data.
VAETrainer: a training helper that applies the VAE objective, tracks training history, and visualizes latent structure.
VAE losses
This section contains loss functions used by the VAE objective.
TC_term
def TC_term(
mean:Array, logvar:Array, z:Array
)->Array:
Compute the total correlation (TC) term.
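The source does not state which estimator TC_term uses. As a hedged illustration only, one common way to estimate a total correlation term from a minibatch compares the log density of the aggregate posterior with the sum of its per-dimension marginals. The function name tc_term_sketch and the naive minibatch averaging are assumptions, not the library's implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tc_term_sketch(mean, logvar, z):
    """Naive minibatch estimate of total correlation (illustrative assumption)."""
    batch = z.shape[0]
    # log q(z_i[d] | x_j) for every sample i, batch element j, latent dim d.
    diff = z[:, None, :] - mean[None, :, :]
    log_q = -0.5 * (
        jnp.log(2.0 * jnp.pi) + logvar[None, :, :] + diff**2 / jnp.exp(logvar)[None, :, :]
    )
    # log q(z_i): joint density over all dims, averaged over the batch.
    log_qz = logsumexp(jnp.sum(log_q, axis=2), axis=1) - jnp.log(batch)
    # sum_d log q(z_i[d]): product of per-dimension marginals.
    log_qz_product = jnp.sum(logsumexp(log_q, axis=1) - jnp.log(batch), axis=1)
    return jnp.mean(log_qz - log_qz_product)
```

With a single latent dimension the joint and the product of marginals coincide, so this estimate is zero up to floating-point error.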
log_normal_pdf
def log_normal_pdf(
mu:Array, logvar:Array, sample:Array
):
Calculates log density of a Gaussian.
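For reference, the log density of a diagonal Gaussian parameterized by a mean and a log-variance can be sketched as follows. The name log_normal_pdf_sketch is hypothetical, and whether the library version sums over the latent dimension is not stated in the source, so this version stays elementwise.

```python
import jax.numpy as jnp


def log_normal_pdf_sketch(mu, logvar, sample):
    """Elementwise log N(sample; mu, exp(logvar))."""
    log_2pi = jnp.log(2.0 * jnp.pi)
    return -0.5 * (log_2pi + logvar + (sample - mu) ** 2 / jnp.exp(logvar))


# Standard normal evaluated at its mean: log(1 / sqrt(2 * pi)).
value = log_normal_pdf_sketch(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0))
```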
kl_standard_normal
def kl_standard_normal(
mean:Array, logvar:Array
)->Array:
KL divergence between diag Gaussian q(z|x) ~ N(mean, exp(logvar)) and p(z)=N(0,I).
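This KL divergence has the closed form 0.5 * sum(exp(logvar) + mean^2 - 1 - logvar). A minimal sketch, assuming the sum runs over the last (latent) axis so that one value is returned per batch element. The name kl_standard_normal_sketch is hypothetical.

```python
import jax.numpy as jnp


def kl_standard_normal_sketch(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, I) ), summed over the latent axis."""
    return 0.5 * jnp.sum(jnp.exp(logvar) + mean**2 - 1.0 - logvar, axis=-1)


# The divergence vanishes when q(z|x) already equals the prior.
kl_zero = kl_standard_normal_sketch(jnp.zeros((4, 2)), jnp.zeros((4, 2)))
```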
gaussian_negloglik
def gaussian_negloglik(
recon_mean:Array, recon_logvar:Array, targets:Array, eps:float=1e-08
)->Array:
reconstruction loss for continuous data
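A Gaussian negative log-likelihood with a predicted mean and log-variance can be sketched as below. How the library uses eps is not documented here; this sketch assumes it is added to the variance for numerical stability and that the loss is summed over the last axis. The name gaussian_negloglik_sketch is hypothetical.

```python
import jax.numpy as jnp


def gaussian_negloglik_sketch(recon_mean, recon_logvar, targets, eps=1e-8):
    """Per-sample Gaussian negative log-likelihood."""
    # Assumption: eps keeps the variance away from zero.
    var = jnp.exp(recon_logvar) + eps
    nll = 0.5 * (jnp.log(2.0 * jnp.pi) + jnp.log(var) + (targets - recon_mean) ** 2 / var)
    # Assumption: sum over features, one value per sample.
    return jnp.sum(nll, axis=-1)
```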
categorical_reconstruction_loss
def categorical_reconstruction_loss(
log_cp_i:Array, targets:Array
)->Array:
reconstruction loss for discrete data
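For discrete data, a reconstruction loss of this kind is typically the negative log-probability assigned to the observed outcomes. The sketch below assumes log_cp has shape (batch, seq_len, local_dimension) and targets holds integer outcome indices of shape (batch, seq_len). Both shapes and the name categorical_reconstruction_loss_sketch are assumptions.

```python
import jax.numpy as jnp


def categorical_reconstruction_loss_sketch(log_cp, targets):
    """Negative log-likelihood of integer targets under per-site log-probabilities."""
    # Select the log-probability of the observed outcome at each site.
    picked = jnp.take_along_axis(log_cp, targets[..., None], axis=-1)[..., 0]
    # Sum over sites to get one loss value per sample.
    return -jnp.sum(picked, axis=-1)
```

For a uniform distribution over two outcomes at three sites, the loss is 3 * log(2) per sample.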
VAE model wrapper
VAEmodel composes an encoder and decoder module. The encoder maps inputs to a Gaussian latent distribution (mean, logvar), the model samples latent variables using the reparameterization trick, and the decoder converts these latent variables back to output conditional probabilities.
This wrapper makes it easy to use different encoder/decoder architectures while keeping a standard VAE interface.
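The reparameterization step mentioned above can be sketched in JAX as follows: the latent sample is written as a deterministic function of (mean, logvar) plus standard normal noise, so gradients can flow through the sampling. The function name reparameterize is hypothetical and not taken from the library.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """Draw z ~ N(mean, exp(logvar)) as mean + std * eps with eps ~ N(0, I)."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


key = jax.random.PRNGKey(0)
z = reparameterize(key, jnp.zeros((2, 3)), jnp.zeros((2, 3)))  # shape (2, 3)
```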
VAEmodel
def VAEmodel(
encoder:Module, decoder:Module, parent:Union=<flax.linen.module._Sentinel object>,
name:Optional=None
)->None:
VAE model, wrapper calling the encoder -> reparam -> decoder
Args:
encoder: nn.Module
decoder: nn.Module
Returns:
mean: (batch, latent_dim)
logvar: (batch, latent_dim)
z: (batch, latent_dim)
cp: (batch, seq_len, local_dimension)
Trainer
VAETrainer manages optimization, batching, and training history for the VAE.
VAETrainer
def VAETrainer(
model:Module, dataset:Dataset, optimizer:GradientTransformation=adabelief
):
Training wrapper for the cpVAE for quantum data
Args:
model: nn.Module (VAEmodel)
dataset: Dataset
optimizer: optax.GradientTransformation
VAETrainer methods
The following methods are implemented for training and analyzing the VAE:
VAETrainer.train
def train(
num_epochs:int, batch_size:int, key:PRNGKey, learning_rate:float=0.001, beta:float=1.0, gamma:float=0,
alpha:float=0.0, printing_rate:int=1, re_shuffle:bool=True
):
High-level training loop.
Args:
num_epochs: int
batch_size: int
key: jax.random.PRNGKey
learning_rate: float
beta: float KL weight
gamma: float TC weight
alpha: float weight balancing the two parts of the hybrid loss
printing_rate: int how often to print info during training
re_shuffle: bool whether to reshuffle the dataset when creating batches at each epoch
Returns:
None
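To illustrate what the re_shuffle option controls, here is a minimal sketch of per-epoch batching with a JAX PRNG key. It is not the trainer's code: the helper name epoch_batches and the choice to drop the last incomplete batch are assumptions.

```python
import jax
import jax.numpy as jnp


def epoch_batches(key, data, batch_size, re_shuffle=True):
    """Split data into equally sized batches, optionally in a fresh random order."""
    n = data.shape[0]
    idx = jax.random.permutation(key, n) if re_shuffle else jnp.arange(n)
    # Assumption: the remainder that does not fill a batch is dropped.
    n_batches = n // batch_size
    idx = idx[: n_batches * batch_size].reshape(n_batches, batch_size)
    return data[idx]


key = jax.random.PRNGKey(0)
for epoch in range(2):
    # Use a new subkey each epoch so every epoch sees a different order.
    key, subkey = jax.random.split(key)
    batches = epoch_batches(subkey, jnp.arange(10), batch_size=3)
```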
VAETrainer.compute_repr2d
def compute_repr2d(
theta_pair:tuple=(1, 0), values_other_thetas:tuple=(), return_latvar:bool=False
):
compute mu, abs(mu), logvar across the parameter space (2d)
VAETrainer.plot_repr2d
def plot_repr2d(
latvar:Optional=None, theta_pair:tuple=(1, 0), subplot:bool=True
):
plot mu, abs(mu), logvar across the parameter space (2d)
VAETrainer.plot_training
def plot_training(
num_epochs:int
):
plot the evolution of the reconstruction loss and the logvar during the training
VAETrainer.get_data
def get_data(
)->dict:
get the data for saving
VAETrainer.reconstruct_sample
def reconstruct_sample(
input_samples:Array, key:PRNGKey
)->Array:
reconstruct input samples: get z from the encoder, then sample from the decoder given z
VAETrainer.get_cp
def get_cp(
input_samples:Array
)->Array:
get the conditional probabilities (cp) of the input samples