Recently I got quite fascinated by integrating the variational auto-encoder [1] technique, and especially the reparameterization trick, into a larger computational graph: in a first stage I was trying to learn embeddings, and then to find “blurry directions or regions” within those embeddings to navigate a larger model through an auto-regressive task. What I stumbled upon was that variational auto-encoders are usually demonstrated on discrete targets with a cross-entropy reconstruction loss. When I changed the problem to a continuous vector space and replaced the cross entropy with a mean squared error loss, while keeping the Kullback-Leibler divergence term of the variational lower bound for the Gaussian parameters of the latent space, I found that it did not simply work out of the box. Reading [2], I found a reference to the appendix of [1]:
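The mean squared error and the Gaussian view are closely related: half the summed squared error equals the Gaussian negative log-likelihood with the variance fixed to one, minus a constant. A minimal check of that identity, assuming random stand-in tensors `x` and `mu` and a hypothetical dimension `D` (none of these come from the original code):

```python
import math

import torch
from torch.distributions import Normal

torch.manual_seed(0)
D = 8                   # hypothetical embedding dimension
x = torch.randn(4, D)   # stand-in targets
mu = torch.randn(4, D)  # stand-in decoder means

# Gaussian negative log-likelihood with variance fixed to 1, summed over dimensions
nll = -Normal(mu, torch.ones_like(mu)).log_prob(x).sum(dim=1)

# half the summed squared error plus a constant that does not depend on mu
sse_half = 0.5 * ((x - mu) ** 2).sum(dim=1)
const = 0.5 * D * math.log(2 * math.pi)

print(torch.allclose(nll, sse_half + const, atol=1e-5))
```

With a fixed unit variance the decoder cannot express how uncertain it is about each dimension; the Gaussian MLP decoder from the appendix adds a learned log-variance head for that.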
C.2 Gaussian MLP as encoder or decoder
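In this case let encoder or decoder be a multivariate Gaussian with a diagonal covariance structure:
$$\log p(\mathbf{x}|\mathbf{z}) = \log \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\mathbf{I})$$
$$\text{where } \boldsymbol{\mu} = \mathbf{W}_4\mathbf{h} + \mathbf{b}_4, \quad \log\boldsymbol{\sigma}^2 = \mathbf{W}_5\mathbf{h} + \mathbf{b}_5, \quad \mathbf{h} = \tanh(\mathbf{W}_3\mathbf{z} + \mathbf{b}_3)$$
where $\{\mathbf{W}_3, \mathbf{W}_4, \mathbf{W}_5, \mathbf{b}_3, \mathbf{b}_4, \mathbf{b}_5\}$ are the weights and biases of the MLP and part of $\boldsymbol{\theta}$ when used as decoder. Note that when this network is used as an encoder $q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})$, then $\mathbf{z}$ and $\mathbf{x}$ are swapped, and the weights and biases are variational parameters $\boldsymbol{\phi}$.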
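Writing out the negative logarithm of this density shows which terms a reconstruction loss for the Gaussian decoder has to contain ($D$ is introduced here to denote the dimension of $\mathbf{x}$):

$$-\log \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\mathbf{I}) = \frac{1}{2}\sum_{j=1}^{D}\left[\log 2\pi + \log\sigma_j^2 + \frac{(x_j-\mu_j)^2}{\sigma_j^2}\right]$$

Since the decoder outputs $\log\boldsymbol{\sigma}^2$ directly, the middle term is the log-variance head itself and the last term divides the squared error by its exponential; the factor $\frac{1}{2}$ applies to all three terms, including the constant $\log 2\pi$.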
and Thomas Viehmann notes in [2] that
So the Gaussian at the reconstruction step has nothing to do (well, except being conditional on the latents) with the Gaussian from the latents (which is the bit where you do the reparametrization and things).
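To make the distinction concrete, here is a minimal sketch with `torch.distributions.Normal`. The sizes `B`, `L`, `D` and the stand-in linear `decoder`, which outputs mean and log-variance in one tensor, are hypothetical and not the architecture used below. The first Gaussian is sampled with the reparameterization trick; the second one only scores the target.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
B, L, D = 4, 3, 6  # hypothetical batch, latent and data sizes

# Gaussian 1: approximate posterior q(z|x); its parameters would come from the encoder
enc_mu = torch.zeros(B, L, requires_grad=True)
enc_logvar = torch.zeros(B, L, requires_grad=True)
eps = torch.randn(B, L)
z = enc_mu + eps * torch.exp(0.5 * enc_logvar)  # reparameterization: z is differentiable in mu and logvar

# Gaussian 2: likelihood p(x|z); a stand-in linear decoder produces its mean and log-variance
decoder = torch.nn.Linear(L, 2 * D)
dec_mu, dec_logvar = decoder(z).chunk(2, dim=-1)
x = torch.randn(B, D)  # observed targets
# for the reconstruction term, x is only scored under this second Gaussian
rec_nll = -Normal(dec_mu, torch.exp(0.5 * dec_logvar)).log_prob(x).sum()

rec_nll.backward()
print(enc_mu.grad.shape, enc_logvar.grad.shape)  # gradients reach the encoder outputs through z
```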
As far as I currently understand it, for a Gaussian MLP as decoder this boils down in Python to:
```python
h_mu, h_logvar = self.encode(v1)
latent = self.reparameterize(h_mu, h_logvar)
z_est_mu, z_est_logvar = self.decode(latent)
z_estimate = self.reparameterize(z_est_mu, z_est_logvar)
loss_rec = 0
KLD = 0
if v2 is not None:
    # Gaussian negative log-likelihood of v2 under N(z_est_mu, exp(z_est_logvar)),
    # summed like the KL term; self._log2pi must hold log(2*pi)
    loss_rec = 0.5 * torch.sum(self._log2pi + z_est_logvar + (v2 - z_est_mu) ** 2 / torch.exp(z_est_logvar))
    # KL(q(z|v1) || N(0, I)), the regularization term of the variational lower bound
    KLD = -0.5 * torch.sum(1 + h_logvar - h_mu.pow(2) - h_logvar.exp())
return self._classifier(v1, z_estimate), loss_rec, KLD
```
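The Gaussian reconstruction term and the KL term can be sanity-checked against PyTorch's own implementations. A sketch with random stand-in tensors (shapes and seed are arbitrary assumptions): `F.gaussian_nll_loss` with `full=True` includes the 0.5*log(2*pi) constant, and `kl_divergence` between two `Normal` distributions gives the closed-form KL per dimension.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, D, L = 5, 6, 3  # arbitrary batch, data and latent sizes
v2 = torch.randn(B, D)
z_est_mu, z_est_logvar = torch.randn(B, D), torch.randn(B, D)
h_mu, h_logvar = torch.randn(B, L), torch.randn(B, L)

# reconstruction term: summed Gaussian negative log-likelihood
loss_rec = 0.5 * torch.sum(math.log(2 * math.pi) + z_est_logvar + (v2 - z_est_mu) ** 2 / torch.exp(z_est_logvar))
ref_rec = F.gaussian_nll_loss(z_est_mu, v2, torch.exp(z_est_logvar), full=True, reduction="sum")

# closed-form KL(N(h_mu, exp(h_logvar)) || N(0, I)), summed
KLD = -0.5 * torch.sum(1 + h_logvar - h_mu.pow(2) - h_logvar.exp())
ref_kld = kl_divergence(Normal(h_mu, torch.exp(0.5 * h_logvar)), Normal(0.0, 1.0)).sum()

print(torch.allclose(loss_rec, ref_rec, rtol=1e-4, atol=1e-4),
      torch.allclose(KLD, ref_kld, rtol=1e-4, atol=1e-4))
```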
Here, for my particular use case, I plug a classifier into the pipeline: it estimates the expected class targets from the two embedding vectors v1 and v2 (it receives v1 and the estimate of v2), while the auto-encoder part should learn a latent representation for estimating v2 based on v1. The complete module looks like this:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class VarAutoEncoder(nn.Module):
    def __init__(self, classifier, dim: int):
        super().__init__()
        self._classifier = classifier  # consumes v1 and the estimate of v2
        hidden_space = max(int(dim / 2), 5)
        self._latent_space = max(int(hidden_space / 2), 5)
        # encoder: shared hidden layer with separate mean and log-variance heads
        self._enc1 = nn.Linear(in_features=dim, out_features=hidden_space)
        self._enc21 = nn.Linear(in_features=hidden_space, out_features=self._latent_space)
        self._enc22 = nn.Linear(in_features=hidden_space, out_features=self._latent_space)
        self._act = nn.Tanh()  # not used below; encode/decode apply ReLU
        # decoder: Gaussian MLP with its own mean and log-variance heads
        self._dec1 = nn.Linear(in_features=self._latent_space, out_features=hidden_space)
        self._dec21 = nn.Linear(in_features=hidden_space, out_features=dim)
        self._dec22 = nn.Linear(in_features=hidden_space, out_features=dim)
        self._log2pi = math.log(2 * math.pi)  # constant of the Gaussian log-density

    def reparameterize(self, mu, log_var):
        """
        :param mu: mean of the Gaussian [B,d]
        :param log_var: log variance of the Gaussian [B,d]
        :return tensor [B,d]
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # `randn_like` as we need the same size
        sample = mu + (eps * std)  # differentiable sample from N(mu, std^2)
        return sample

    def encode(self, x):  # x=[B,G]
        h1 = F.relu(self._enc1(x))
        return self._enc21(h1), self._enc22(h1)  # mu=[B,L] / log_var=[B,L]

    def decode(self, x):  # x=[B,L]
        h1 = F.relu(self._dec1(x))
        return self._dec21(h1), self._dec22(h1)  # mu=[B,G] / log_var=[B,G]

    def forward(self, v1, v2=None):
        """Return (classifier output, reconstruction loss, KL divergence)."""
        h_mu, h_logvar = self.encode(v1)
        latent = self.reparameterize(h_mu, h_logvar)
        z_est_mu, z_est_logvar = self.decode(latent)
        z_estimate = self.reparameterize(z_est_mu, z_est_logvar)
        loss_rec = 0
        KLD = 0
        if v2 is not None:
            # Gaussian negative log-likelihood of v2 under N(z_est_mu, exp(z_est_logvar)), summed
            loss_rec = 0.5 * torch.sum(self._log2pi + z_est_logvar + (v2 - z_est_mu) ** 2 / torch.exp(z_est_logvar))
            # KL(q(z|v1) || N(0, I)), the regularization term of the variational lower bound
            KLD = -0.5 * torch.sum(1 + h_logvar - h_mu.pow(2) - h_logvar.exp())
        return self._classifier(v1, z_estimate), loss_rec, KLD
```
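The module returns the classifier output together with the reconstruction and KL terms and leaves combining them to the caller. A minimal sketch of one way to do that, assuming the classifier returns logits for a cross-entropy loss, and introducing a hypothetical helper `combined_loss` with a `beta` weight on the KL term in the spirit of beta-VAE (neither is part of the original code):

```python
import torch
import torch.nn.functional as F


def combined_loss(logits, targets, loss_rec, kld, beta=1.0):
    """Hypothetical objective: class cross-entropy plus a beta-weighted negative ELBO.

    loss_rec and kld are treated as sums over the batch and divided by the batch
    size, so they sit on the same per-example scale as the averaged cross-entropy.
    """
    batch_size = logits.shape[0]
    ce = F.cross_entropy(logits, targets)               # mean over the batch
    rec = torch.as_tensor(loss_rec).sum() / batch_size  # also accepts an unreduced tensor or 0
    kl = torch.as_tensor(kld).sum() / batch_size
    return ce + rec + beta * kl


# usage with random stand-ins for the module outputs
torch.manual_seed(0)
logits = torch.randn(8, 3, requires_grad=True)
targets = torch.randint(0, 3, (8,))
loss_rec = torch.rand(8, 16).sum()
kld = torch.tensor(2.5)
loss = combined_loss(logits, targets, loss_rec, kld, beta=0.5)
loss.backward()
print(loss.item())
```

Dividing the summed terms by the batch size keeps them comparable to the averaged cross-entropy; otherwise the auto-encoder terms grow with the batch size and can swamp the classification loss.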
References
@article{kingma2013auto,
title={Auto-encoding variational bayes},
author={Kingma, Diederik P and Welling, Max},
journal={arXiv preprint arXiv:1312.6114},
year={2013}
}