# VAE


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## VAE + losses

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L37"
target="_blank" style="float:right; font-size:smaller">source</a>

### log_importance_weight_matrix

``` python

def log_importance_weight_matrix(
    batch_size:int, dataset_size:Array
)->Array:

```

*Calculates a log importance weight matrix.*

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L29"
target="_blank" style="float:right; font-size:smaller">source</a>

### matrix_log_density_gaussian

``` python

def matrix_log_density_gaussian(
    x:Array, mu:Array, logvar:Array
)->Array:

```

*Calculates log density of a Gaussian for all batch pairs.*
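
As a rough illustration of what "all batch pairs" can mean, the NumPy sketch below evaluates every sample `x[i]` under every Gaussian `(mu[j], logvar[j])` by broadcasting. The name `matrix_log_density_gaussian_sketch`, the `(batch, latent_dim)` input shapes, and the `(batch, batch, latent_dim)` output shape are assumptions made for this example, not taken from the library source.

``` python
import numpy as np


def matrix_log_density_gaussian_sketch(x, mu, logvar):
    """Log density of x[i] under Gaussian j, kept separate per latent dimension.

    Assumed shapes: x, mu, logvar are (batch, latent_dim).
    Returns an array of shape (batch, batch, latent_dim).
    """
    x = x[:, None, :]            # (batch, 1, latent_dim)
    mu = mu[None, :, :]          # (1, batch, latent_dim)
    logvar = logvar[None, :, :]  # (1, batch, latent_dim)
    # Elementwise log N(x; mu, exp(logvar)), broadcast over both batch axes.
    return -0.5 * (np.log(2.0 * np.pi) + logvar + (x - mu) ** 2 / np.exp(logvar))


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
mu = rng.normal(size=(4, 3))
logvar = np.zeros((4, 3))
print(matrix_log_density_gaussian_sketch(x, mu, logvar).shape)  # (4, 4, 3)
```

Entry `[i, j, d]` pairs sample `i` with distribution `j` along latent dimension `d`, which is the layout a total correlation estimate typically needs.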

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L160"
target="_blank" style="float:right; font-size:smaller">source</a>

### TC_term

``` python

def TC_term(
    mean:Array, logvar:Array, z:Array
)->Array:

```

*Computes the total correlation (TC) term from the encoder outputs `mean` and `logvar` and the sampled latents `z`.*
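
For orientation, a naive minibatch estimate of total correlation can be sketched as follows. This is not claimed to be the estimator used in `TC_term`: it ignores dataset-size corrections such as importance weights, and the names `pairwise_log_density` and `tc_term_sketch` are hypothetical. It estimates log q(z) and the product of marginals log Π_j q(z_j) by averaging the pairwise densities over the batch.

``` python
import numpy as np


def pairwise_log_density(z, mean, logvar):
    """log N(z[i]; mean[j], exp(logvar[j])) per dimension, shape (batch, batch, dim)."""
    z = z[:, None, :]
    mean = mean[None, :, :]
    logvar = logvar[None, :, :]
    return -0.5 * (np.log(2.0 * np.pi) + logvar + (z - mean) ** 2 / np.exp(logvar))


def tc_term_sketch(mean, logvar, z):
    """Naive minibatch estimate of E[log q(z) - sum_j log q(z_j)] (an assumption)."""
    batch = z.shape[0]
    log_q = pairwise_log_density(z, mean, logvar)
    # log q(z_i): joint density over dimensions, averaged over the batch.
    log_qz = np.logaddexp.reduce(log_q.sum(axis=2), axis=1) - np.log(batch)
    # sum_j log q(z_ij): each marginal averaged over the batch, then summed.
    log_prod_qzj = (np.logaddexp.reduce(log_q, axis=1) - np.log(batch)).sum(axis=1)
    return np.mean(log_qz - log_prod_qzj)


rng = np.random.default_rng(0)
mean = rng.normal(size=(8, 3))
logvar = rng.normal(size=(8, 3))
z = mean + np.exp(0.5 * logvar) * rng.normal(size=(8, 3))
print(tc_term_sketch(mean, logvar, z))
```

With a single latent dimension the joint and the product of marginals coincide, so the estimate is zero, which gives a quick sanity check.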

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L151"
target="_blank" style="float:right; font-size:smaller">source</a>

### log_density_gaussian

``` python

def log_density_gaussian(
    x:Array, mu:Array, logvar:Array
):

```

*Calculates log density of a Gaussian.*

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L147"
target="_blank" style="float:right; font-size:smaller">source</a>

### log_normal_pdf

``` python

def log_normal_pdf(
    mu:Array, logvar:Array, sample:Array
):

```

*Calculates the log density of `sample` under a Gaussian with mean `mu` and log-variance `logvar`.*

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L140"
target="_blank" style="float:right; font-size:smaller">source</a>

### kl_standard_normal

``` python

def kl_standard_normal(
    mean:Array, logvar:Array
)->Array:

```

*KL divergence between the diagonal Gaussian q(z|x) = N(mean, exp(logvar))
and the standard normal prior p(z) = N(0, I).*
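
The closed form of this divergence is 0.5 · Σ (exp(logvar) + mean² − 1 − logvar). A minimal NumPy sketch is given below; summing over the last axis to get one value per sample is an assumption, and the library's actual reduction may differ. The name `kl_standard_normal_sketch` is hypothetical.

``` python
import numpy as np


def kl_standard_normal_sketch(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over the last (latent) axis."""
    return 0.5 * np.sum(np.exp(logvar) + mean ** 2 - 1.0 - logvar, axis=-1)


mean = np.array([[0.0, 0.0], [1.0, 0.0]])
logvar = np.zeros((2, 2))
print(kl_standard_normal_sketch(mean, logvar))  # [0.  0.5]
```

The first row matches the prior exactly, so its divergence is zero. The second row shifts one mean by 1, which contributes 0.5.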

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L132"
target="_blank" style="float:right; font-size:smaller">source</a>

### gaussian_negloglik

``` python

def gaussian_negloglik(
    recon_mean:Array, recon_logvar:Array, targets:Array, eps:float=1e-08
)->Array:

```

*Reconstruction loss for continuous data: the Gaussian negative log-likelihood of `targets` given `recon_mean` and `recon_logvar`.*
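
A Gaussian negative log-likelihood of this kind can be sketched in NumPy as shown below. How `eps` enters (here it is added to the variance for numerical stability) and the elementwise return value are assumptions for illustration. The name `gaussian_negloglik_sketch` is hypothetical.

``` python
import numpy as np


def gaussian_negloglik_sketch(recon_mean, recon_logvar, targets, eps=1e-8):
    """Elementwise -log N(targets; recon_mean, exp(recon_logvar))."""
    var = np.exp(recon_logvar) + eps  # assumed use of eps: keep the variance positive
    return 0.5 * (np.log(2.0 * np.pi) + np.log(var) + (targets - recon_mean) ** 2 / var)


targets = np.array([0.0, 1.0, 2.0])
recon_mean = np.array([0.0, 1.0, 0.0])
recon_logvar = np.zeros(3)
nll = gaussian_negloglik_sketch(recon_mean, recon_logvar, targets)
print(nll.shape)  # (3,)
```

A perfect prediction with unit variance costs 0.5 · log(2π) per element, and the cost grows quadratically with the error.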

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L124"
target="_blank" style="float:right; font-size:smaller">source</a>

### categorical_reconstruction_loss

``` python

def categorical_reconstruction_loss(
    log_cp_i:Array, targets:Array
)->Array:

```

*Reconstruction loss for discrete data.*
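
One common form of such a loss is the negative log-probability of the observed categories. The sketch below assumes that `log_cp` holds log-probabilities of shape `(batch, seq_len, local_dimension)` and that `targets` holds integer category indices of shape `(batch, seq_len)`. Both shapes, the per-sample sum, and the name `categorical_reconstruction_loss_sketch` are assumptions rather than the library's documented behavior.

``` python
import numpy as np


def categorical_reconstruction_loss_sketch(log_cp, targets):
    """Negative log-likelihood of integer targets, summed over the sequence axis."""
    # Select the log-probability of the observed category at every position.
    picked = np.take_along_axis(log_cp, targets[..., None], axis=-1)[..., 0]
    return -picked.sum(axis=-1)  # shape (batch,)


# Uniform predictions over 2 categories for a sequence of length 3.
log_cp = np.log(np.full((2, 3, 2), 0.5))
targets = np.array([[0, 1, 0], [1, 1, 1]])
loss = categorical_reconstruction_loss_sketch(log_cp, targets)
print(loss.shape)  # (2,)
```

With uniform predictions every position costs log 2, so each sample's loss is 3 · log 2 regardless of the targets.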

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L72"
target="_blank" style="float:right; font-size:smaller">source</a>

### VAEmodel

``` python

def VAEmodel(
    encoder:Module, decoder:Module, parent:Union=<flax.linen.module._Sentinel object at 0x7f09c1ac9460>,
    name:Optional=None
)->None:

```

*VAE model: a wrapper that calls the encoder, applies the reparameterization trick, then calls the decoder.*

Args:

    encoder: nn.Module
    decoder: nn.Module

Returns:

    mean: (batch, latent_dim)
    logvar: (batch, latent_dim)
    z: (batch, latent_dim)
    cp: (batch, seq_len, local_dimension)
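
The reparameterization step in the middle of that pipeline is usually written as z = mean + exp(0.5 · logvar) · ε with ε ~ N(0, I). The NumPy sketch below illustrates only this step. The function name `reparameterize` and the use of a NumPy random generator are assumptions; a Flax module would typically draw ε with a JAX PRNG key instead.

``` python
import numpy as np


def reparameterize(rng, mean, logvar):
    """Draw z ~ N(mean, exp(logvar)) as a deterministic function of noise eps."""
    eps = rng.standard_normal(mean.shape)     # eps ~ N(0, I)
    return mean + np.exp(0.5 * logvar) * eps  # same shape as mean: (batch, latent_dim)


rng = np.random.default_rng(0)
mean = np.zeros((4, 2))
logvar = np.zeros((4, 2))
z = reparameterize(rng, mean, logvar)
print(z.shape)  # (4, 2)
```

Because the randomness is isolated in ε, gradients can flow through `mean` and `logvar` during training.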

## Trainer

------------------------------------------------------------------------

<a
href="https://github.com/qic-ibk/qdisc/blob/main/qdisc/vae/core.py#L173"
target="_blank" style="float:right; font-size:smaller">source</a>

### VAETrainer

``` python

def VAETrainer(
    model:Module, dataset:Dataset, optimizer:GradientTransformation=adabelief
):

```

*Training wrapper for the cpVAE for quantum*

Args:

    model: nn.Module (VAEmodel)

    dataset: Dataset

    optimizer: optax.GradientTransformation
