"""Source code for nlgm.autoencoder."""

from torch import nn
from nlgm.manifolds import ProductManifold


# ==========================================================================================
# ============================= FOR USE IN `mnist_experiment.ipynb` =======================
# ==========================================================================================


class Encoder(nn.Module):
    """Convolutional encoder for the geometric autoencoder.

    A stack of six 3x3 conv -> batch-norm -> ReLU stages (the third one
    strided) followed by global average pooling, producing a flat
    ``latent_dim`` vector per input image.
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        """
        Args:
            hidden_dim (int): Number of hidden channels in the conv stack.
            latent_dim (int): Number of latent dimensions.
        """
        super().__init__()

        def conv_stage(in_ch, out_ch, stride=1):
            # One 3x3 conv -> batch-norm -> ReLU unit; spatial size is
            # preserved unless stride > 1.
            return [
                nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        stages = []
        stages += conv_stage(1, hidden_dim)
        stages += conv_stage(hidden_dim, hidden_dim)
        stages += conv_stage(hidden_dim, hidden_dim, stride=2)  # downsample 2x
        stages += conv_stage(hidden_dim, hidden_dim)
        stages += conv_stage(hidden_dim, hidden_dim)
        stages += conv_stage(hidden_dim, latent_dim)
        # Collapse the remaining spatial extent to 1x1, then flatten to (N, latent_dim).
        stages += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x (torch.Tensor): Input batch of shape (N, 1, H, W).

        Returns:
            torch.Tensor: Latent codes of shape (N, latent_dim).
        """
        return self.encoder(x)
class Decoder(nn.Module):
    """Convolutional decoder for the geometric autoencoder.

    Projects a latent vector to a 7x7 feature map, upsamples it twice
    (7 -> 14 -> 28) with transposed convolutions, and emits a single-channel
    image with values in [0, 1].
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        """
        Args:
            hidden_dim (int): Number of hidden channels in the deconv stack.
            latent_dim (int): Number of latent dimensions.
        """
        super().__init__()
        layers = [
            # Project the latent code to a hidden_dim x 7 x 7 feature map.
            nn.Linear(latent_dim, hidden_dim * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (hidden_dim, 7, 7)),
        ]
        # Two 2x upsampling stages: 7x7 -> 14x14 -> 28x28.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(
                    hidden_dim, hidden_dim, 3, stride=2, padding=1, output_padding=1
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(hidden_dim),
            ]
        layers += [
            nn.Conv2d(hidden_dim, 1, 3, padding=1),  # reduce to 1 channel
            nn.Sigmoid(),  # pixel values in [0, 1]
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes into images.

        Args:
            z (torch.Tensor): Latent batch of shape (N, latent_dim).

        Returns:
            torch.Tensor: Reconstructions of shape (N, 1, 28, 28) in [0, 1].
        """
        return self.decoder(z)
class GeometricAutoencoder(nn.Module):
    """Autoencoder whose latent space lives on a product manifold.

    The Euclidean encoder output is projected onto the manifold via the
    exponential map before being decoded.
    """

    def __init__(self, signature, hidden_dim=20, latent_dim=2):
        """
        Args:
            signature (list): Signature (list of per-factor dimensions)
                defining the latent product manifold.
            hidden_dim (int): Number of hidden dimensions for the
                encoder/decoder.
            latent_dim (int): Number of latent dimensions.
        """
        super().__init__()
        self.geometry = ProductManifold(signature)
        self.encoder = Encoder(hidden_dim, latent_dim)
        self.decoder = Decoder(hidden_dim, latent_dim)

    def forward(self, x):
        """Encode, project onto the manifold, and decode.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            torch.Tensor: Reconstructed batch from the decoder.
        """
        latent = self.encoder(x)
        # Map the Euclidean encoder output onto the product manifold.
        on_manifold = self.geometry.exponential_map(latent)
        return self.decoder(on_manifold)