Skip to content

Commit

Permalink
cleanup + comments
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Dec 1, 2023
1 parent 9f5ea5d commit ae9cd3e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 119 deletions.
18 changes: 11 additions & 7 deletions RVGP/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
compute_laplacian,
compute_connection_laplacian,
compute_spectrum,
project_to_local_frame,
express_in_local_frame,
project_to_manifold,
manifold_dimension
)
Expand All @@ -30,7 +30,8 @@ def __init__(self,
n_eigenpairs=None):

print('Fit graph')
G = manifold_graph(vertices,n_neighbors=n_neighbors)
G = manifold_graph(vertices, n_neighbors=n_neighbors)

print('Fit tangent spaces')
gauges, Sigma = tangent_frames(vertices, G, vertices.shape[1], n_neighbors*frac_geodesic_neighbours)

Expand Down Expand Up @@ -93,10 +94,13 @@ def __init__(self,
def random_vector_field(self, seed=0):
"""Generate random vector field over manifold"""

np.random.seed(0)
np.random.seed(seed)

vectors = np.random.uniform(size=(len(self.vertices), 3))-.5
vectors = project_to_manifold(vectors, self.gauges[...,:2])
vectors = np.random.uniform(size=(len(self.vertices),
self.vertices.shape[1])
)
vectors -= .5
vectors = project_to_manifold(vectors, self.gauges[...,:self.dim_man])
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

self.vectors = vectors
Expand All @@ -106,9 +110,9 @@ def smooth_vector_field(self, t=100):
if hasattr(self, 'vectors'):

"""Smooth vector field over manifold"""
vectors = project_to_local_frame(self.vectors, self.gauges)
vectors = express_in_local_frame(self.vectors, self.gauges)
vectors = vector_diffusion(vectors, t, L=self.L, Lc=self.Lc, method="matrix_exp")
vectors = project_to_local_frame(vectors, self.gauges, reverse=True)
vectors = express_in_local_frame(vectors, self.gauges, reverse=True)

self.vectors = vectors
else:
Expand Down
137 changes: 26 additions & 111 deletions RVGP/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,20 @@

from sklearn.metrics import pairwise_distances
from sklearn.neighbors import kneighbors_graph
from sklearn.neighbors import KDTree



def compute_laplacian(G, normalization=False):

if normalization:
laplacian = sparse.csr_matrix(nx.normalized_laplacian_matrix(G), dtype=np.float64)
else:
laplacian = sparse.csr_matrix(nx.laplacian_matrix(G), dtype=np.float64)

return laplacian


def compute_connection_laplacian(G, R, normalization=None):
r"""Connection Laplacian
"""Connection Laplacian
Args:
data: Pytorch geometric data object.
R (nxnxdxd): Connection matrices between all pairs of nodes. Default is None,
in case of a global coordinate system.
normalization: None, 'sym', 'rw'
normalization: None, 'rw'
1. None: No normalization
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
2. "sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}`
3. "rw"`: Random-walk normalization
2. "rw"`: Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
Returns:
Expand All @@ -65,12 +49,20 @@ def compute_connection_laplacian(G, R, normalization=None):
deg_inv = deg_inv.repeat(dim, axis=0)
Lc = sparse.diags(deg_inv, 0, format='csr') @ Lc

elif normalization == "sym":
raise NotImplementedError

return Lc


def compute_laplacian(G, normalization=False):
"""Laplacian. Used as a helper function to compute_connection_laplacian()"""

if normalization:
laplacian = sparse.csr_matrix(nx.normalized_laplacian_matrix(G), dtype=np.float64)
else:
laplacian = sparse.csr_matrix(nx.laplacian_matrix(G), dtype=np.float64)

return laplacian


def compute_spectrum(laplacian, n_eigenpairs=None, dtype=tf.float64):

if n_eigenpairs is None:
Expand All @@ -88,32 +80,6 @@ def compute_spectrum(laplacian, n_eigenpairs=None, dtype=tf.float64):
return evals, evecs


def sample_from_convex_hull(points, num_samples, k=5):

tree = scipy.spatial.KDTree(points)

if num_samples > len(points):
num_samples = len(points)

sample_points = np.random.choice(len(points), size=num_samples, replace=False)
sample_points = points[sample_points]

# Generate samples
samples = []
for current_point in sample_points:
_, nn_ind = tree.query(current_point, k=k, p=2)
nn_hull = points[nn_ind]

barycentric_coords = np.random.uniform(size=nn_hull.shape[0])
barycentric_coords /= np.sum(barycentric_coords)

current_point = np.sum(nn_hull.T * barycentric_coords, axis=1)

samples.append(current_point)

return np.array(samples)


def manifold_dimension(Sigma, frac_explained=0.9):
"""Estimate manifold dimension based on singular vectors"""

Expand All @@ -134,7 +100,14 @@ def manifold_dimension(Sigma, frac_explained=0.9):
def manifold_graph(X, typ = 'knn', n_neighbors=5):
"""Fit graph over a pointset X"""
if typ == 'knn':
A = kneighbors_graph(X, n_neighbors, mode='connectivity', metric='minkowski', p=2, metric_params=None, include_self=False, n_jobs=None)
A = kneighbors_graph(X,
n_neighbors,
mode='connectivity',
metric='minkowski',
p=2,
metric_params=None,
include_self=False,
n_jobs=None)
A += sparse.eye(A.shape[0])
G = nx.from_scipy_sparse_array(A)

Expand All @@ -150,53 +123,6 @@ def manifold_graph(X, typ = 'knn', n_neighbors=5):
return G


def find_nn(x_query, X, nn=3, r=None):
"""
Find nearest neighbors of a point on the manifold
Parameters
----------
ind_query : 2d np array, list[2d np array]
Index of points whose neighbors are needed.
x : nxd array (dimensions are columns!)
Coordinates of n points on a manifold in d-dimensional space.
nn : int, optional
Number of nearest neighbors. The default is 1.
Returns
-------
dist : list[list]
Distance of nearest neighbors.
ind : list[list]
Index of nearest neighbors.
"""

#Fit neighbor estimator object
kdt = KDTree(X, leaf_size=30, metric='euclidean')

if r is not None:
ind, dist = kdt.query_radius(x_query, r=r, return_distance=True, sort_results=True)
ind = ind[0]
dist = dist[0]
else:
# apparently, the outputs are reversed here compared to query_radius()
dist, ind = kdt.query(x_query, k=nn)

return dist, ind.flatten()


def closest_manifold_point(x_query, d, nn=3):
dist, ind = find_nn(x_query, d.vertices, nn=nn)
w = 1/(dist.T+0.00001)
w /= w.sum()
positional_encoding = d.evecs_Lc.reshape(d.n, -1)
pe_manifold = (positional_encoding[ind]*w).sum(0, keepdims=True)
x_manifold = d.vertices[ind]

return x_manifold, pe_manifold


def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
"""A greedy O(N^2) algorithm to do furthest points sampling
Expand All @@ -217,8 +143,6 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
n = D.shape[0] if N is None else N
diam = D.max()

start_idx = 5

perm = np.zeros(n, dtype=np.int32)
perm[0] = start_idx
lambdas = np.zeros(n)
Expand All @@ -239,23 +163,14 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):


def project_to_manifold(x, gauges):
"""Project vectors to local coordinates over manifold"""
coeffs = np.einsum("bij,bi->bj", gauges, x)
return np.einsum("bj,bij->bi", coeffs, gauges)


def project_to_local_frame(x, gauges, reverse=False):
def express_in_local_frame(x, gauges, reverse=False):
"""Express vectors in local coordinates over manifold"""
if reverse:
return np.einsum("bji,bi->bj", gauges, x)
else:
return np.einsum("bij,bi->bj", gauges, x)


def local_to_global(x, gauges):
return np.einsum("bj,bij->bi", x, gauges)


def node_eigencoords(node_ind, evecs_Lc, dim):
r, c = evecs_Lc.shape
evecs_Lc = evecs_Lc.reshape(-1, c*dim)
node_coords = evecs_Lc[node_ind]
return node_coords.reshape(-1, c)
return np.einsum("bij,bi->bj", gauges, x)
1 change: 0 additions & 1 deletion examples/eeg_example/run_eeg_vector_field_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# Load EEG vector field data
# =============================================================================


def find_mat_files(directory):
mat_files = {}
for root, dirs, files in os.walk(directory):
Expand Down

0 comments on commit ae9cd3e

Please sign in to comment.