from torch import nn
from nlgm.manifolds import ProductManifold
# ==========================================================================================
# ============================ FOR USE IN `mnist_experiment.ipynb` =========================
# ==========================================================================================
class Encoder(nn.Module):
    """
    Convolutional encoder for the geometric autoencoder.

    The network is a stack of six 3x3 convolutions (``padding=1``), each
    followed by batch normalisation and ReLU. The third convolution uses
    ``stride=2`` and is the only layer that reduces spatial resolution. The
    final feature map has ``latent_dim`` channels and is averaged over its
    spatial dimensions, then flattened to one vector per sample.

    Parameters
    ----------
    hidden_dim : int
        Number of hidden channels in intermediate convolution layers.
    latent_dim : int
        Size of the latent representation.

    Notes
    -----
    The last convolution is followed by a ReLU before the average pooling,
    so every component of the returned latent vector is non-negative.
    """

    def __init__(self, hidden_dim: int = 20, latent_dim: int = 2) -> None:
        super().__init__()
        # The layer order (and therefore the Sequential indices used as
        # ``state_dict`` keys) is kept exactly as before.
        self.encoder = nn.Sequential(
            # Single-channel input -> hidden_dim feature maps.
            nn.Conv2d(1, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            # Only strided layer: downsamples the spatial dimensions.
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            # Project to latent_dim channels.
            nn.Conv2d(hidden_dim, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.ReLU(inplace=True),
            # Global average pool to 1x1, then drop the spatial axes.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        """
        Encode an input batch.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of images with a single channel, i.e. shape
            ``(batch, 1, height, width)``.

        Returns
        -------
        torch.Tensor
            Latent representation of shape ``(batch, latent_dim)``.
        """
        return self.encoder(x)
class Decoder(nn.Module):
    """
    Decoder network for the geometric autoencoder.

    A linear layer expands each latent vector into a ``hidden_dim x 7 x 7``
    feature map, two stride-2 transposed convolutions upsample it to
    ``14 x 14`` and then ``28 x 28``, and a final 3x3 convolution followed by
    a sigmoid produces a single-channel image.

    Parameters
    ----------
    hidden_dim : int
        Number of hidden channels in intermediate layers.
    latent_dim : int
        Size of the latent representation.
    """

    def __init__(self, hidden_dim: int = 20, latent_dim: int = 2) -> None:
        super().__init__()
        # The layer order (and therefore the Sequential indices used as
        # ``state_dict`` keys) is kept exactly as before.
        self.decoder = nn.Sequential(
            # Expand the latent vector and reshape it into a 7x7 feature map.
            nn.Linear(latent_dim, hidden_dim * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (hidden_dim, 7, 7)),
            # Upsample to 14x14.
            nn.ConvTranspose2d(
                hidden_dim, hidden_dim, 3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(hidden_dim),
            # Upsample to 28x28.
            nn.ConvTranspose2d(
                hidden_dim, hidden_dim, 3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(hidden_dim),
            # Reduce to 1 channel.
            nn.Conv2d(hidden_dim, 1, 3, padding=1),
            # Squash the output into the range [0, 1].
            nn.Sigmoid(),
        )

    def forward(self, z):
        """
        Decode latent vectors back to image space.

        Parameters
        ----------
        z : torch.Tensor
            Latent tensor of shape ``(batch, latent_dim)``.

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor of shape ``(batch, 1, 28, 28)`` with
            values in ``[0, 1]``.
        """
        return self.decoder(z)
class GeometricAutoencoder(nn.Module):
    """
    Autoencoder with latent projection to a product manifold.

    The encoder output is passed through the ``exponential_map`` of a
    ``ProductManifold`` built from ``signature`` before being decoded.

    Parameters
    ----------
    signature : list
        Curvature signature for the product manifold; forwarded unchanged to
        ``ProductManifold``.
    hidden_dim : int
        Number of hidden channels in intermediate layers.
    latent_dim : int
        Size of the latent representation.

    Notes
    -----
    The decoder is built with the same ``latent_dim`` as the encoder, so
    ``exponential_map`` is expected to return tensors whose last dimension
    is still ``latent_dim`` -- TODO confirm against ``ProductManifold``.
    """

    def __init__(self, signature, hidden_dim: int = 20, latent_dim: int = 2) -> None:
        super().__init__()
        self.geometry = ProductManifold(signature)
        self.encoder = Encoder(hidden_dim, latent_dim)
        self.decoder = Decoder(hidden_dim, latent_dim)

    def forward(self, x):
        """
        Encode, project to manifold, and decode.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor.

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor.
        """
        latent = self.encoder(x)
        # Map the encoder output onto the product manifold before decoding.
        on_manifold = self.geometry.exponential_map(latent)
        return self.decoder(on_manifold)