"""Source code for nlgm.autoencoder."""

from torch import nn
from nlgm.manifolds import ProductManifold


# ==========================================================================================
# ============================= FOR USE IN `mnist_experiment.ipynb`=========================
# ==========================================================================================


class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel images to flat latent vectors.

    The network is six 3x3 convolution blocks (Conv2d -> BatchNorm2d -> ReLU).
    Only the third block uses stride 2, so the spatial resolution is halved
    exactly once. Global average pooling then collapses the remaining spatial
    grid, and the result is flattened to shape ``(batch, latent_dim)``.

    Parameters
    ----------
    hidden_dim : int
        Channel width of the intermediate convolution blocks.
    latent_dim : int
        Output channels of the last convolution, i.e. the length of each
        latent vector returned by :meth:`forward`.

    Notes
    -----
    NOTE(review): the last block ends in a ReLU, so every latent coordinate is
    non-negative. Confirm this is intended: ``GeometricAutoencoder`` passes
    these vectors directly to ``exponential_map``.
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        super().__init__()

        def conv_block(c_in, c_out, stride=1):
            # 3x3 "same"-padded convolution, then BatchNorm and in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # (in_channels, out_channels, stride) per block. Order matters: it fixes
        # the Sequential indices (state_dict keys) and the RNG draw order.
        plan = [
            (1, hidden_dim, 1),
            (hidden_dim, hidden_dim, 1),
            (hidden_dim, hidden_dim, 2),  # the only downsampling step
            (hidden_dim, hidden_dim, 1),
            (hidden_dim, hidden_dim, 1),
            (hidden_dim, latent_dim, 1),
        ]
        layers = [
            layer
            for c_in, c_out, step in plan
            for layer in conv_block(c_in, c_out, step)
        ]
        # Pool each channel to a single value, then drop the 1x1 spatial dims.
        layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to latent vectors.

        Parameters
        ----------
        x : torch.Tensor
            Image batch. Presumably shaped ``(batch, 1, H, W)``, since the first
            convolution expects a single input channel.

        Returns
        -------
        torch.Tensor
            Latent codes of shape ``(batch, latent_dim)``.
        """
        return self.encoder(x)
class Decoder(nn.Module):
    """Decoder that turns latent vectors into single-channel 28x28 images.

    A linear layer lifts each latent vector to a ``hidden_dim x 7 x 7`` feature
    map. Two stride-2 transposed convolutions then upsample it to 14x14 and
    then 28x28. A final 3x3 convolution followed by a sigmoid produces one
    output channel with values in [0, 1].

    Parameters
    ----------
    hidden_dim : int
        Channel width of the intermediate feature maps.
    latent_dim : int
        Length of each input latent vector.
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        super().__init__()

        def upsample(channels):
            # Transposed conv that doubles spatial size ((H, W) -> (2H, 2W)),
            # followed by ReLU and then BatchNorm (in that order).
            return [
                nn.ConvTranspose2d(
                    channels, channels, 3, stride=2, padding=1, output_padding=1
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(channels),
            ]

        # Layers are appended strictly in forward order so the Sequential
        # indices and parameter-initialisation order are fixed.
        layers = [
            nn.Linear(latent_dim, hidden_dim * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (hidden_dim, 7, 7)),
        ]
        layers += upsample(hidden_dim)  # 7x7 -> 14x14
        layers += upsample(hidden_dim)  # 14x14 -> 28x28
        layers += [
            nn.Conv2d(hidden_dim, 1, 3, padding=1),  # collapse to one channel
            nn.Sigmoid(),  # squash pixel values into [0, 1]
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent vectors back to image space.

        Parameters
        ----------
        z : torch.Tensor
            Latent batch of shape ``(batch, latent_dim)``.

        Returns
        -------
        torch.Tensor
            Reconstructed images of shape ``(batch, 1, 28, 28)``.
        """
        return self.decoder(z)
class GeometricAutoencoder(nn.Module):
    """Autoencoder whose latent codes are mapped onto a product manifold.

    The forward pass encodes images with :class:`Encoder`, then passes the
    latent vectors through ``ProductManifold.exponential_map``, then decodes
    the result with :class:`Decoder`.

    Parameters
    ----------
    signature : list
        Curvature signature, forwarded unchanged to ``ProductManifold``.
    hidden_dim : int
        Hidden channel width shared by the encoder and decoder.
    latent_dim : int
        Latent vector length shared by the encoder and decoder.
        NOTE(review): presumably this must match the dimension implied by
        ``signature``; confirm against ``ProductManifold``.
    """

    def __init__(self, signature, hidden_dim=20, latent_dim=2):
        super().__init__()
        # Keep this creation order: it determines submodule registration order
        # (state_dict / parameters()) and the order of random initialisation.
        self.geometry = ProductManifold(signature)
        self.encoder = Encoder(hidden_dim, latent_dim)
        self.decoder = Decoder(hidden_dim, latent_dim)

    def forward(self, x):
        """Reconstruct ``x`` through the manifold-valued latent space.

        Parameters
        ----------
        x : torch.Tensor
            Image batch accepted by :class:`Encoder`.

        Returns
        -------
        torch.Tensor
            Reconstructed images as produced by :class:`Decoder`.
        """
        code = self.encoder(x)
        projected = self.geometry.exponential_map(code)
        return self.decoder(projected)