Variational Auto-Encoder with Gaussian Decoder

Recently I got quite fascinated by integrating the variational auto-encoder 1 technique, and especially the reparameterization trick, into a larger computational graph. In a first stage I was trying to learn embeddings. I then wanted to find "blurry directions or regions" within those embeddings to navigate a larger model through an auto-regressive task. What I stumbled upon was that variational auto-encoders are usually used with discrete (e.g. binary) targets and a cross-entropy reconstruction loss. I changed the problem to a continuous vector space and replaced the cross entropy with a mean squared error loss. I kept the Kullback-Leibler divergence term of the variational lower bound for the Gaussian parameters of the latent space. It did not simply work out of the box. Reading 2, I found a reference to the appendix of 1:

C.2 Gaussian MLP as encoder or decoder

In this case let encoder or decoder be a multivariate Gaussian with a diagonal covariance structure: $\log p(x|z) = \log \mathcal{N}(x; \mu, \sigma^2 I)$, where $\mu = W_4 h + b_4$, $\log \sigma^2 = W_5 h + b_5$, $h = \tanh(W_3 z + b_3)$, where $\{W_3, W_4, W_5, b_3, b_4, b_5\}$ are the weights and biases of the MLP and part of $\theta$ when used as decoder. Note that when this network is used as an encoder $q_{\phi}(z|x)$, then $z$ and $x$ are swapped, and the weights and biases are variational parameters $\phi$.

and Thomas Viehmann notes in 2 that

So the Gaussian at the reconstruction step has nothing to do (well, except being conditional on the latents) with the Gaussian from the latents (which is the bit where you do the reparametrization and things).

and, to my current understanding, for a Gaussian MLP as decoder this boils down in Python to

h_mu, h_logvar = self.encode(v1)
latent = self.reparameterize(h_mu, h_logvar)
z_est_mu, z_est_logvar = self.decode(latent)
z_estimate = self.reparameterize(z_est_mu, z_est_logvar)

loss_rec = 0
KLD = 0
if v2 is not None:
    loss_rec = 0.5 * torch.sum(self._log2pi + z_est_logvar + (v2 - z_est_mu) ** 2 / torch.exp(z_est_logvar))  # negative log-likelihood of v2 under N(z_est_mu, exp(z_est_logvar))
    KLD = -0.5 * torch.sum(1 + h_logvar - h_mu.pow(2) - h_logvar.exp())  # KL(q(z|v1) || N(0, I)), the regularizer of the negative variational lower bound

return self._classifier(v1, z_estimate), loss_rec, KLD

in which, for my particular use case, I integrate a classifier into the whole pipeline. The classifier estimates expected class targets from two embedding vectors $v_1$ and $v_2$. The auto-encoder part learns a latent representation for estimating $v_2$ from $v_1$.
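The two loss terms can be sanity-checked against `torch.distributions`. For decoder outputs $\mu, \log\sigma^2$ and a target $x$, the reconstruction term is the Gaussian negative log-likelihood $-\log \mathcal{N}(x;\mu,\sigma^2 I) = \frac{1}{2}\sum_j \left(\log 2\pi + \log\sigma_j^2 + (x_j-\mu_j)^2/\sigma_j^2\right)$. Note that the factor $\frac{1}{2}$ applies to the $\log\sigma_j^2$ term as well. The KL divergence of the encoder's diagonal Gaussian from the standard normal prior is $-\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The sketch below compares both closed forms with the library's reference values. The function names `gaussian_nll` and `kl_to_standard_normal` are my own, and the shapes and random inputs are arbitrary.

```python
import math

import torch
from torch.distributions import Normal, kl_divergence


def gaussian_nll(x, mu, logvar):
    """Closed-form -log N(x; mu, diag(exp(logvar))), summed over all elements."""
    return 0.5 * torch.sum(math.log(2 * math.pi) + logvar + (x - mu) ** 2 / torch.exp(logvar))


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over all elements."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
x = torch.randn(4, 8)
mu, logvar = torch.randn(4, 8), torch.randn(4, 8)

# Reference values from torch.distributions (Normal takes the standard deviation as scale).
q = Normal(mu, torch.exp(0.5 * logvar))
ref_nll = -q.log_prob(x).sum()
ref_kl = kl_divergence(q, Normal(torch.zeros_like(mu), torch.ones_like(logvar))).sum()

print(torch.allclose(gaussian_nll(x, mu, logvar), ref_nll, rtol=1e-4, atol=1e-4))          # True
print(torch.allclose(kl_to_standard_normal(mu, logvar), ref_kl, rtol=1e-4, atol=1e-4))     # True
```

Summing both terms over batch and dimensions keeps them on the same scale, so they can simply be added to form the negative variational lower bound.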


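import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VarAutoEncoder(nn.Module):
    def __init__(self, classifier, dim: int):
        super().__init__()
        self._classifier = classifier  # downstream model called as classifier(v1, estimate_of_v2)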

        hidden_space = max(int(dim/2), 5)
        self._latent_space = max(int(hidden_space/2), 5)
        self._enc1 = nn.Linear(in_features=dim, out_features=hidden_space)
        self._enc21 = nn.Linear(in_features=hidden_space, out_features=self._latent_space)
        self._enc22 = nn.Linear(in_features=hidden_space, out_features=self._latent_space)
        self._act = nn.Tanh()

        self._dec1 = nn.Linear(in_features=self._latent_space, out_features=hidden_space)
        self._dec21 = nn.Linear(in_features=hidden_space, out_features=dim)
        self._dec22 = nn.Linear(in_features=hidden_space, out_features=dim)

        self._log2pi = float(np.log(2 * np.pi))  # constant log(2*pi) of the Gaussian log-density

    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space [B,d]
        :param log_var: log variance from the encoder's latent space [B,d]
        :return tensor [B,d]
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std)  # a draw from N(mu, std^2) that stays differentiable w.r.t. mu and log_var
        return sample

    def encode(self, x):  # x=[B,G]
        h1 = F.relu(self._enc1(x))
        return self._enc21(h1), self._enc22(h1)  # mu=[B,L] / log_var=[B,L]

    def decode(self, x):  # x=[B,L]
        h1 = F.relu(self._dec1(x))
        return self._dec21(h1), self._dec22(h1)  # mu=[B,G] / log_var=[B,G]

    def forward(self, v1, v2=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h_mu, h_logvar = self.encode(v1)
        latent = self.reparameterize(h_mu, h_logvar)
        z_est_mu, z_est_logvar = self.decode(latent)
        z_estimate = self.reparameterize(z_est_mu, z_est_logvar)

        loss_rec = 0
        KLD = 0
        if v2 is not None:
            loss_rec = 0.5 * torch.sum(self._log2pi + z_est_logvar + (v2 - z_est_mu) ** 2 / torch.exp(z_est_logvar))  # negative log-likelihood of v2 under N(z_est_mu, exp(z_est_logvar))
            KLD = -0.5 * torch.sum(1 + h_logvar - h_mu.pow(2) - h_logvar.exp())  # KL(q(z|v1) || N(0, I)), the regularizer of the negative variational lower bound

        return self._classifier(v1, z_estimate), loss_rec, KLD
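The `reparameterize` method is what makes the sampling step trainable. The sample is a deterministic function of `mu`, `log_var` and independent noise `eps`, so gradients flow back into the encoder, and the samples still follow $\mathcal{N}(\mu, \sigma^2)$. The sketch below checks both properties with a free-function copy of the method. The batch size and $\sigma = 2$ are arbitrary choices for illustration.

```python
import math

import torch


def reparameterize(mu, log_var):
    # Free-function copy of VarAutoEncoder.reparameterize.
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std


torch.manual_seed(0)
n = 10_000
mu = torch.zeros(n, 2, requires_grad=True)
log_var = torch.full((n, 2), math.log(4.0), requires_grad=True)  # sigma^2 = 4, so sigma = 2

sample = reparameterize(mu, log_var)
sample.mean().backward()  # any scalar function of the sample would do

print(mu.grad is not None, log_var.grad is not None)  # True True
print(sample.std(dim=0))  # both entries close to 2
```

Sampling directly with `torch.normal(mu, std)` would give values from the same distribution, but no gradient would flow back to `mu` and `log_var`.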

References

  1. Auto-Encoding Variational Bayes
@article{kingma2013auto,
  title={Auto-encoding variational bayes},
  author={Kingma, Diederik P and Welling, Max},
  journal={arXiv preprint arXiv:1312.6114},
  year={2013}
}