import torch
from typing import List
[docs]
class BasicManifold:
"""
Base manifold class with shared geometry attributes.
Parameters
----------
dimension : int
Dimension of the manifold.
curvature : float
Scalar curvature of the manifold.
base_point : torch.Tensor, optional
Origin point of the tangent space. If omitted, a zero vector of
length ``dimension`` is used.
"""
def __init__(
self, dimension: int, curvature: float, base_point: torch.Tensor = None
):
"""
Initialize a manifold descriptor.
Parameters
----------
dimension : int
Dimension of the manifold.
curvature : float
Scalar curvature of the manifold.
base_point : torch.Tensor, optional
Origin point of the tangent space.
"""
self.dimension = dimension
self.curvature = torch.tensor(curvature)
self.base_point = (
base_point if base_point is not None else torch.zeros(dimension)
)
[docs]
def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
"""
Map a tangent-space vector onto the manifold.
Parameters
----------
tangent_vector : torch.Tensor
Tangent vector to map.
Returns
-------
torch.Tensor
Mapped point on the manifold.
"""
pass
[docs]
def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
"""
Compute geodesic distance between two points on the manifold.
Parameters
----------
point_x : torch.Tensor
First point.
point_y : torch.Tensor
Second point.
Returns
-------
torch.Tensor
Geodesic distance between ``point_x`` and ``point_y``.
"""
pass
[docs]
class EuclideanManifold(BasicManifold):
[docs]
def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
"""
Map a tangent vector in Euclidean space.
Parameters
----------
tangent_vector : torch.Tensor
Tangent vector to map.
Returns
-------
torch.Tensor
The mapped point, equal to ``tangent_vector`` in Euclidean space.
"""
# For Euclidean, the exponential map is an identity function
return tangent_vector
[docs]
def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
"""
Compute Euclidean distance between two points.
Parameters
----------
point_x : torch.Tensor
First point.
point_y : torch.Tensor
Second point.
Returns
-------
torch.Tensor
Euclidean distance between ``point_x`` and ``point_y``.
"""
# Compute the Euclidean distance between point_x and point_y
return torch.norm(point_x - point_y, dim=-1)
[docs]
class SphericalManifold(BasicManifold):
[docs]
def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
"""
Map a tangent vector onto a spherical manifold.
Parameters
----------
tangent_vector : torch.Tensor
Tangent vector to map.
Returns
-------
torch.Tensor
Point on the spherical manifold after exponential mapping.
"""
device = torch.get_device(tangent_vector)
# Compute the L2 norm of the tangent vector
# $\sqrt{K_S}|x|$
norm_v = torch.sqrt(torch.abs(self.curvature.to(device))) * torch.norm(
tangent_vector, dim=-1, keepdim=True
)
# Compute the unit direction vector by dividing the tangent vector by its norm
# $\frac{x}{\sqrt{K_S}\|x\|}$
direction = tangent_vector / norm_v.clamp_min(1e-6)
# Apply the exponential map formula for spherical manifold
# $\cos(\sqrt{K_S}\|x\|)x_p + \sin(\sqrt{K_S}\|x\|)\frac{x}{\sqrt{K_S}\|x\|}$
return (
torch.cos(norm_v) * self.base_point.to(device)
+ torch.sin(norm_v) * direction
)
[docs]
def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
"""
Compute spherical geodesic distance between two points.
Parameters
----------
point_x : torch.Tensor
First point.
point_y : torch.Tensor
Second point.
Returns
-------
torch.Tensor
Geodesic distance between ``point_x`` and ``point_y``.
"""
device = torch.get_device(point_x)
# Compute the inner product between point_x and point_y
# $(x,y)_2 := \langle\mathbf{x}, \mathbf{y}\rangle$
inner_product = (point_x * point_y).sum(dim=-1)
# Compute the geodesic distance using the arc-cosine of the inner product
# $\arccos(K_S * (x,y)_2) / \sqrt{|K|}$
return torch.acos(
self.curvature.to(device) * inner_product.clamp(-1.0, 1.0)
) / torch.sqrt(self.curvature)
[docs]
class HyperbolicManifold(BasicManifold):
[docs]
def exponential_map(self, tangent_vector: torch.Tensor) -> torch.Tensor:
"""
Map a tangent vector onto a hyperbolic manifold.
Parameters
----------
tangent_vector : torch.Tensor
Tangent vector to map.
Returns
-------
torch.Tensor
Point on the hyperbolic manifold after exponential mapping.
"""
device = torch.get_device(tangent_vector)
# Compute the L2 norm of the tangent vector
# $\sqrt{-K_H}|x|$
norm_v = torch.sqrt(torch.abs(self.curvature.to(device))) * torch.norm(
tangent_vector, dim=-1, keepdim=True
)
# Compute the unit direction vector by dividing the tangent vector by its norm
# $\frac{x}{\sqrt{-K_H}\|x\|}$
direction = tangent_vector / norm_v.clamp_min(1e-6)
# Apply the exponential map formula for hyperbolic manifold
# $\cosh(\sqrt{-K_H}\|x\|)x_p + \sinh(\sqrt{-K_H}\|x\|)\frac{x}{\sqrt{-K_H}\|x\|}$
return (
torch.cosh(norm_v) * self.base_point.to(device)
+ torch.sinh(norm_v) * direction
)
[docs]
def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
"""
Compute hyperbolic geodesic distance between two points.
Parameters
----------
point_x : torch.Tensor
First point.
point_y : torch.Tensor
Second point.
Returns
-------
torch.Tensor
Geodesic distance between ``point_x`` and ``point_y``.
"""
device = torch.get_device(point_x)
# Compute the Lorentz inner product between point_x and point_y
# $-(x_p,y_1)_L$, where $(x,y)_L$ denotes the Lorentz inner product
inner_product = -(point_x[..., 0] * point_y[..., 0]) + (
point_x[..., 1:] * point_y[..., 1:]
).sum(dim=-1)
# Compute the geodesic distance using the arc-hyperbolic cosine of the inner product
# $\frac{1}{\sqrt{-K_H}} \mathrm{arccosh}(K_H*(x,y)_L)$
return torch.acosh(
self.curvature.to(device) * inner_product.clamp_min(1.0)
) / torch.sqrt(-self.curvature)
[docs]
class ProductManifold:
"""
Product manifold assembled from fixed-size component manifolds.
Parameters
----------
curvatures : list[float]
Curvatures for each component manifold.
Notes
-----
Each component manifold is assumed to have dimension 2.
"""
def __init__(self, curvatures: List[float]):
"""
Initialize a product manifold from curvature values.
Parameters
----------
curvatures : list[float]
Curvatures for each component manifold.
"""
self.manifolds = []
self.dimensions = [2 for _ in curvatures]
for dimension, curvature in zip(self.dimensions, curvatures):
if curvature == 0:
self.manifolds.append(EuclideanManifold(dimension, curvature))
elif curvature > 0:
self.manifolds.append(SphericalManifold(dimension, curvature))
else: # curvature < 0
self.manifolds.append(HyperbolicManifold(dimension, curvature))
[docs]
def exponential_map(self, latent_vector: torch.Tensor) -> torch.Tensor:
"""
Apply per-component exponential maps and concatenate the result.
Parameters
----------
latent_vector : torch.Tensor
Latent vector in Euclidean space. Its last dimension should match
the sum of component manifold dimensions.
Returns
-------
torch.Tensor
Projection of ``latent_vector`` into the product manifold space.
"""
# Apply mapping projection of component manifold to corresponding segment of latent vector
segments = self._get_segments(latent_vector)
mapped_segments = [
manifold.exponential_map(segment)
for manifold, segment in zip(self.manifolds, segments)
]
# Concatenate the mapped segments along the last dimension to form a single tensor
return torch.cat(mapped_segments, dim=-1)
def _get_segments(self, tangent_vector):
"""
Split a vector into per-component manifold segments.
Parameters
----------
tangent_vector : torch.Tensor
Vector to segment.
Returns
-------
list[torch.Tensor]
Segments matching ``self.dimensions``.
"""
segments = []
start = 0
for dim in self.dimensions:
end = start + dim
segment = tangent_vector[..., start:end] # Supports batch operations
segments.append(segment)
start = end
return segments
[docs]
def distance(self, point_x: torch.Tensor, point_y: torch.Tensor) -> torch.Tensor:
"""
Compute product-manifold distance between two points.
Parameters
----------
point_x : torch.Tensor
First point in product-manifold coordinates.
point_y : torch.Tensor
Second point in product-manifold coordinates.
Returns
-------
torch.Tensor
Distance between ``point_x`` and ``point_y``.
"""
x_segments = self._get_segments(point_x)
y_segments = self._get_segments(point_y)
# Compute the squared distances for each component manifold
squared_distances = [
manifold.distance(x_segment, y_segment) ** 2
for manifold, x_segment, y_segment in zip(
self.manifolds, x_segments, y_segments
)
]
# Sum the squared distances and take the square root
return torch.sqrt(torch.stack(squared_distances, dim=-1).sum(dim=-1))