from typing import List, Optional

import torch
class BasicManifold:
    """
    A base class for manifolds, providing a common interface for dimension, curvature, and base point.

    Subclasses are expected to override ``exponential_map`` and ``distance``;
    the implementations here only raise ``NotImplementedError``.

    Args:
        dimension (int): The dimension of the manifold.
        curvature (float): The curvature of the manifold.
        base_point (torch.Tensor, optional): The origin point of the tangent space.
            Defaults to a zero vector of length ``dimension``.
    """

    def __init__(
        self,
        dimension: int,
        curvature: float,
        base_point: Optional[torch.Tensor] = None,
    ):
        """
        Initializes a BasicManifold object.

        Args:
            dimension (int): The dimension of the manifold.
            curvature (float): The curvature of the manifold.
            base_point (torch.Tensor, optional): The origin point of the tangent space.
                Defaults to None, in which case ``torch.zeros(dimension)`` is used.
        """
        self.dimension = dimension
        # Kept as a 0-dim tensor so it can be used directly in tensor arithmetic.
        self.curvature = torch.tensor(curvature)
        self.base_point = (
            base_point if base_point is not None else torch.zeros(dimension)
        )

    def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
        """
        Applies the exponential map to the given tangent vector.

        Args:
            tangent_vector (torch.Tensor): The tangent vector to be mapped.

        Returns:
            torch.Tensor: The result of applying the exponential map to the tangent vector.

        Raises:
            NotImplementedError: Always; concrete manifolds must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement exponential_map()"
        )

    def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
        """
        Computes the geodesic distance between two points.

        The signature matches the overrides in the concrete manifold classes.

        Args:
            point_x (torch.Tensor): The first point.
            point_y (torch.Tensor): The second point.

        Returns:
            torch.Tensor: The geodesic distance between the two points.

        Raises:
            NotImplementedError: Always; concrete manifolds must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement distance()")
class EuclideanManifold(BasicManifold):
    """Flat (zero-curvature) manifold component."""

    def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
        """
        Maps a tangent vector onto the Euclidean manifold.

        Args:
            tangent_vector (torch.Tensor): The tangent vector to be mapped.

        Returns:
            torch.Tensor: The input tensor itself, returned unchanged.
        """
        # In flat space the exponential map is treated as the identity,
        # so the very same tensor object is handed back.
        return tangent_vector

    def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
        """
        Computes the straight-line distance between two points in Euclidean space.

        Args:
            point_x (torch.Tensor): The first point.
            point_y (torch.Tensor): The second point.

        Returns:
            torch.Tensor: The norm of ``point_x - point_y`` taken over the last axis.
        """
        displacement = point_x - point_y
        return displacement.norm(dim=-1)
class SphericalManifold(BasicManifold):
    """Positively curved manifold component."""

    def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
        """
        Applies the exponential map to the given tangent vector in Spherical space.

        Args:
            tangent_vector (torch.Tensor): The tangent vector to be mapped; the
                norm is taken over its last axis.

        Returns:
            torch.Tensor: The result of applying the exponential map to the tangent vector.
        """
        # Use the tensor's own device object. torch.get_device() returns -1 for
        # CPU tensors, which is not a valid target for Tensor.to().
        device = tangent_vector.device
        # Scaled norm of the tangent vector
        # $\sqrt{K_S}\|x\|$
        norm_v = torch.sqrt(torch.abs(self.curvature.to(device))) * torch.norm(
            tangent_vector, dim=-1, keepdim=True
        )
        # Direction term; the denominator is clamped to avoid division by zero
        # for (near-)zero tangent vectors.
        # $\frac{x}{\sqrt{K_S}\|x\|}$
        direction = tangent_vector / norm_v.clamp_min(1e-6)
        # Apply the exponential map formula for spherical manifold
        # $\cos(\sqrt{K_S}\|x\|)x_p + \sin(\sqrt{K_S}\|x\|)\frac{x}{\sqrt{K_S}\|x\|}$
        # NOTE(review): with the default all-zero base point the cosine term
        # vanishes -- confirm that this is the intended base point.
        return (
            torch.cos(norm_v) * self.base_point.to(device)
            + torch.sin(norm_v) * direction
        )

    def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
        """
        Computes the geodesic distance between two points in Spherical space.

        Args:
            point_x (torch.Tensor): The first point.
            point_y (torch.Tensor): The second point.

        Returns:
            torch.Tensor: The geodesic distance between the two points.
        """
        device = point_x.device
        curvature = self.curvature.to(device)
        # Compute the inner product between point_x and point_y
        # $(x,y)_2 := \langle\mathbf{x}, \mathbf{y}\rangle$
        inner_product = (point_x * point_y).sum(dim=-1)
        # The arc-cosine argument is K_S * (x,y)_2, so that product (not the raw
        # inner product) is what must be kept inside acos's domain [-1, 1].
        cos_angle = (curvature * inner_product).clamp(-1.0, 1.0)
        # $\arccos(K_S * (x,y)_2) / \sqrt{K_S}$
        return torch.acos(cos_angle) / torch.sqrt(curvature)
class HyperbolicManifold(BasicManifold):
    """Negatively curved manifold component."""

    def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
        """
        Applies the exponential map to the given tangent vector in Hyperbolic space.

        Args:
            tangent_vector (torch.Tensor): The tangent vector to be mapped; the
                norm is taken over its last axis.

        Returns:
            torch.Tensor: The result of applying the exponential map to the tangent vector.
        """
        # Use the tensor's own device object. torch.get_device() returns -1 for
        # CPU tensors, which is not a valid target for Tensor.to().
        device = tangent_vector.device
        # Scaled norm of the tangent vector
        # $\sqrt{-K_H}\|x\|$
        norm_v = torch.sqrt(torch.abs(self.curvature.to(device))) * torch.norm(
            tangent_vector, dim=-1, keepdim=True
        )
        # Direction term; the denominator is clamped to avoid division by zero
        # for (near-)zero tangent vectors.
        # $\frac{x}{\sqrt{-K_H}\|x\|}$
        direction = tangent_vector / norm_v.clamp_min(1e-6)
        # Apply the exponential map formula for hyperbolic manifold
        # $\cosh(\sqrt{-K_H}\|x\|)x_p + \sinh(\sqrt{-K_H}\|x\|)\frac{x}{\sqrt{-K_H}\|x\|}$
        return (
            torch.cosh(norm_v) * self.base_point.to(device)
            + torch.sinh(norm_v) * direction
        )

    def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
        """
        Computes the geodesic distance between two points in Hyperbolic space.

        The first coordinate along the last axis enters the inner product with
        a negative sign; the remaining coordinates enter with a positive sign.

        Args:
            point_x (torch.Tensor): The first point.
            point_y (torch.Tensor): The second point.

        Returns:
            torch.Tensor: The geodesic distance between the two points.
        """
        device = point_x.device
        curvature = self.curvature.to(device)
        # Compute the Lorentz inner product between point_x and point_y
        # $(x,y)_L = -x_0 y_0 + \sum_{i \geq 1} x_i y_i$
        inner_product = -(point_x[..., 0] * point_y[..., 0]) + (
            point_x[..., 1:] * point_y[..., 1:]
        ).sum(dim=-1)
        # arccosh is only defined for arguments >= 1. The argument is
        # K_H * (x,y)_L, so the clamp must be applied to that product; clamping
        # the inner product first and then multiplying by the negative
        # curvature would always give a negative argument and hence NaN.
        cosh_arg = (curvature * inner_product).clamp_min(1.0)
        # $\frac{1}{\sqrt{-K_H}} \mathrm{arccosh}(K_H*(x,y)_L)$
        return torch.acosh(cosh_arg) / torch.sqrt(-curvature)
class ProductManifold:
    """
    Represents a product manifold constructed from multiple manifold components,
    each characterized by its dimension and curvature.

    Args:
        curvatures (List[float]): A list containing the curvatures of component manifolds.
        dimensions (List[int], optional): The dimension of each component manifold.
            Defaults to None, in which case every component has dimension 2.

    Attributes:
        manifolds (List[BasicManifold]): A list of manifold objects representing the components of the product manifold.
        dimensions (List[int]): A list of dimensions of each component manifold.
    """

    def __init__(
        self, curvatures: List[float], dimensions: Optional[List[int]] = None
    ):
        """
        Initializes a ProductManifold object.

        Args:
            curvatures (List[float]): A list containing the curvatures of component manifolds.
            dimensions (List[int], optional): The dimension of each component manifold.
                Defaults to None, in which case every component has dimension 2.

        Raises:
            ValueError: If ``dimensions`` is given and its length differs from
                the number of curvatures.
        """
        # Materialize once so a one-shot iterable is not exhausted before the
        # zip below.
        curvatures = list(curvatures)
        if dimensions is None:
            dimensions = [2] * len(curvatures)
        else:
            dimensions = list(dimensions)
            if len(dimensions) != len(curvatures):
                raise ValueError(
                    f"Expected {len(curvatures)} dimensions (one per curvature), "
                    f"got {len(dimensions)}"
                )
        self.dimensions = dimensions
        self.manifolds = [
            self._make_component(dimension, curvature)
            for dimension, curvature in zip(self.dimensions, curvatures)
        ]

    @staticmethod
    def _make_component(dimension, curvature):
        """Picks the component manifold class from the sign of the curvature."""
        if curvature == 0:
            return EuclideanManifold(dimension, curvature)
        if curvature > 0:
            return SphericalManifold(dimension, curvature)
        # curvature < 0
        return HyperbolicManifold(dimension, curvature)

    def exponential_map(self, latent_vector: torch.Tensor) -> torch.Tensor:
        """
        Applies the exponential map of each component manifold to the corresponding segment of the input latent vector
        and returns a concatenated tensor representing the projection into the product manifold space.

        Args:
            latent_vector (torch.Tensor): A latent vector in Euclidean space to be mapped to the product manifold space.
                Its last dimension should match the sum of the dimensions of the component manifolds.

        Returns:
            torch.Tensor: The mapped segments concatenated along the last dimension.
        """
        # Apply mapping projection of component manifold to corresponding segment of latent vector
        segments = self._get_segments(latent_vector)
        mapped_segments = [
            manifold.exponential_map(segment)
            for manifold, segment in zip(self.manifolds, segments)
        ]
        # Concatenate the mapped segments along the last dimension to form a single tensor
        return torch.cat(mapped_segments, dim=-1)

    def _get_segments(self, tangent_vector):
        """
        Aligns component manifolds with corresponding segments of input tangent vector.

        Args:
            tangent_vector (torch.Tensor): The input tangent vector.

        Returns:
            List[torch.Tensor]: A list of tensor segments, each corresponding to a component manifold.
        """
        segments = []
        start = 0
        for dim in self.dimensions:
            end = start + dim
            # Slice only the last axis so any leading (batch) axes are preserved
            segments.append(tangent_vector[..., start:end])
            start = end
        return segments

    def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
        """
        Computes the distance between two points in the product manifold space.

        The result is the square root of the sum of the squared component
        distances.

        Args:
            point_x (torch.Tensor): The first point in the product manifold space.
            point_y (torch.Tensor): The second point in the product manifold space.

        Returns:
            torch.Tensor: The distance between the two points in the product manifold space.
        """
        x_segments = self._get_segments(point_x)
        y_segments = self._get_segments(point_y)
        # Compute the squared distances for each component manifold
        squared_distances = [
            manifold.distance(x_segment, y_segment) ** 2
            for manifold, x_segment, y_segment in zip(
                self.manifolds, x_segments, y_segments
            )
        ]
        # Sum the squared distances and take the square root
        return torch.sqrt(torch.stack(squared_distances, dim=-1).sum(dim=-1))