from torch import nn
from nlgm.manifolds import ProductManifold
# ==========================================================================================
# ============================= FOR USE IN `mnist_experiment.ipynb`=========================
# ==========================================================================================
[docs]
class Encoder(nn.Module):
    def __init__(self, hidden_dim=20, latent_dim=2):
        """
        Encoder class for the geometric autoencoder.
        Args:
            hidden_dim (int): Number of hidden dimensions.
            latent_dim (int): Number of latent dimensions.
        """
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
[docs]
    def forward(self, x):
        """
        Forward pass of the encoder.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Encoded output tensor.
        """
        z = self.encoder(x)
        return z 
 
[docs]
class Decoder(nn.Module):
    def __init__(self, hidden_dim=20, latent_dim=2):
        """
        Decoder class for the geometric autoencoder.
        Args:
            hidden_dim (int): Number of hidden dimensions.
            latent_dim (int): Number of latent dimensions.
        """
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (hidden_dim, 7, 7)),
            nn.ConvTranspose2d(
                hidden_dim, hidden_dim, 3, stride=2, padding=1, output_padding=1
            ),  # Upsample to 14x14
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(hidden_dim),
            nn.ConvTranspose2d(
                hidden_dim, hidden_dim, 3, stride=2, padding=1, output_padding=1
            ),  # Upsample to 28x28
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(hidden_dim),
            nn.Conv2d(hidden_dim, 1, 3, padding=1),  # Reduce to 1 channel
            nn.Sigmoid(),  # Output in range [0, 1]
        )
[docs]
    def forward(self, z):
        """
        Forward pass of the decoder.
        Args:
            z (torch.Tensor): Encoded input tensor.
        Returns:
            torch.Tensor: Decoded output tensor.
        """
        x_recon = self.decoder(z)
        return x_recon 
 
[docs]
class GeometricAutoencoder(nn.Module):
    def __init__(self, signature, hidden_dim=20, latent_dim=2):
        """
        Geometric Autoencoder class.
        Args:
            signature (list): List of signature dimensions.
            hidden_dim (int): Number of hidden dimensions.
            latent_dim (int): Number of latent dimensions.
        """
        super(GeometricAutoencoder, self).__init__()
        self.geometry = ProductManifold(signature)
        self.encoder = Encoder(hidden_dim, latent_dim)
        self.decoder = Decoder(hidden_dim, latent_dim)
[docs]
    def forward(self, x):
        """
        Forward pass of the geometric autoencoder.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Decoded output tensor.
        """
        z = self.encoder(x)
        z = self.geometry.exponential_map(z)
        x_recon = self.decoder(z)
        return x_recon